3.3 异构图 GraphConv 模块
HeteroGraphConv
是用于在异构图上运行 DGL NN 模块的模块级封装。其实现逻辑与消息传递级 API multi_update_all()
相同,包括
每个关系
内的 DGL NN 模块。将来自多个关系的同一节点类型的结果合并的归约操作。
这可以表达为
其中
HeteroGraphConv 实现逻辑:
import torch.nn as nn
class HeteroGraphConv(nn.Module):
def __init__(self, mods, aggregate='sum'):
super(HeteroGraphConv, self).__init__()
self.mods = nn.ModuleDict(mods)
if isinstance(aggregate, str):
# An internal function to get common aggregation functions
self.agg_fn = get_aggregate_fn(aggregate)
else:
self.agg_fn = aggregate
异构图卷积接收一个字典 mods
,该字典将每个关系映射到一个 nn 模块,并设置将来自多个关系的同一节点类型的结果进行聚合的函数。
def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
if mod_args is None:
mod_args = {}
if mod_kwargs is None:
mod_kwargs = {}
outputs = {nty : [] for nty in g.dsttypes}
除了输入图和输入张量,forward()
函数还接受两个额外的字典参数 mod_args
和 mod_kwargs
。这两个字典与 self.mods
具有相同的键。它们在调用 self.mods
中对应不同关系类型的 NN 模块时用作定制参数。
创建一个输出字典,用于存储每个目标类型 nty
的输出张量。请注意,每个 nty
的值是一个列表,表示如果多个关系以 nty
作为目标类型,则单个节点类型可能会获得多个输出。HeteroGraphConv
将对这些列表执行进一步的聚合。
if g.is_block:
src_inputs = inputs
dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
else:
src_inputs = dst_inputs = inputs
for stype, etype, dtype in g.canonical_etypes:
rel_graph = g[stype, etype, dtype]
if rel_graph.num_edges() == 0:
continue
if stype not in src_inputs or dtype not in dst_inputs:
continue
dstdata = self.mods[etype](
rel_graph,
(src_inputs[stype], dst_inputs[dtype]),
*mod_args.get(etype, ()),
**mod_kwargs.get(etype, {}))
outputs[dtype].append(dstdata)
输入 g
可以是异构图,也可以是异构图的子图块。与普通 NN 模块一样,forward()
函数需要分别处理不同的输入图类型。
每个关系表示为一个 canonical_etype
,即 (stype, etype, dtype)
。使用 canonical_etype
作为键,可以提取出一个二分图 rel_graph
。对于二分图,输入特征将组织为一个元组 (src_inputs[stype], dst_inputs[dtype])
。将调用每个关系的 NN 模块并保存输出。为避免不必要的调用,没有边的关系或没有源类型节点的关系将被跳过。
rsts = {}
for nty, alist in outputs.items():
if len(alist) != 0:
rsts[nty] = self.agg_fn(alist, nty)
最后,使用 self.agg_fn
函数聚合来自多个关系的同一目标节点类型的结果。示例可以在 HeteroGraphConv
的 API 文档中找到。