HeteroGraphConv

class dgl.nn.pytorch.HeteroGraphConv(mods, aggregate='sum')[source]

基类: Module

一个用于在异构图上进行卷积计算的通用模块。

异构图卷积将子模块应用于其关联的关系图,从源节点读取特征并将其更新到目标节点。如果多个关系具有相同的目标节点类型,则其结果将通过指定的方法进行聚合。如果关系图没有边,则不会调用相应的模块。

伪代码

outputs = {nty : [] for nty in g.dsttypes}
# Apply sub-modules on their associating relation graphs in parallel
for relation in g.canonical_etypes:
    stype, etype, dtype = relation
    dstdata = relation_submodule(g[relation], ...)
    outputs[dtype].append(dstdata)

# Aggregate the results for each destination node type
rsts = {}
for ntype, ntype_outputs in outputs.items():
    if len(ntype_outputs) != 0:
        rsts[ntype] = aggregate(ntype_outputs)
return rsts

示例

创建一个包含三种关系类型和节点类型的异构图。

>>> import dgl
>>> g = dgl.heterograph({
...     ('user', 'follows', 'user') : edges1,
...     ('user', 'plays', 'game') : edges2,
...     ('store', 'sells', 'game')  : edges3})

创建一个 HeteroGraphConv,它对不同的关系应用不同的卷积模块。注意,'follows''plays' 关系对应的模块不共享权重。

>>> import dgl.nn.pytorch as dglnn
>>> conv = dglnn.HeteroGraphConv({
...     'follows' : dglnn.GraphConv(...),
...     'plays' : dglnn.GraphConv(...),
...     'sells' : dglnn.SAGEConv(...)},
...     aggregate='sum')

使用一些 'user' 特征调用 forward。这将计算 'user''game' 节点的最新特征。

>>> import torch as th
>>> h1 = {'user' : th.randn((g.num_nodes('user'), 5))}
>>> h2 = conv(g, h1)
>>> print(h2.keys())
dict_keys(['user', 'game'])

同时使用 'user''store' 特征调用 forward。由于 'plays''sells' 关系都会更新 'game' 特征,因此它们的结果会通过指定的方法(此处为求和)进行聚合。

>>> f1 = {'user' : ..., 'store' : ...}
>>> f2 = conv(g, f1)
>>> print(f2.keys())
dict_keys(['user', 'game'])

使用一些 'store' 特征调用 forward。这只计算 'game' 节点的最新特征。

>>> g1 = {'store' : ...}
>>> g2 = conv(g, g1)
>>> print(g2.keys())
dict_keys(['game'])

允许使用一对输入调用 forward,并且每个子模块也将使用一对输入进行调用。

>>> x_src = {'user' : ..., 'store' : ...}
>>> x_dst = {'user' : ..., 'game' : ...}
>>> y_dst = conv(g, (x_src, x_dst))
>>> print(y_dst.keys())
dict_keys(['user', 'game'])
参数:
  • mods (dict[str, nn.Module]) – 与每种边类型关联的模块。每个模块的 forward 函数必须将一个 DGLGraph 对象作为第一个参数,其第二个参数可以是表示节点特征的张量对象,或表示源节点和目标节点特征的张量对象对。

  • aggregate (str, callable, optional) –

    聚合不同关系生成的节点特征的方法。允许的字符串值为 ‘sum’, ‘max’, ‘min’, ‘mean’, ‘stack’。‘stack’ 聚合沿第二维执行,其顺序是确定的。用户还可以通过提供可调用实例来自定义聚合器。例如,求和聚合等效于以下代码

    def my_agg_func(tensors, dsttype):
        # tensors: is a list of tensors to aggregate
        # dsttype: string name of the destination node type for which the
        #          aggregation is performed
        stacked = torch.stack(tensors, dim=0)
        return torch.sum(stacked, dim=0)
    

mods

与每种边类型关联的模块。

类型:

dict[str, nn.Module]

forward(g, inputs, mod_args=None, mod_kwargs=None)[source]

前向计算

调用每个模块的 forward 函数并聚合其结果。

参数:
  • g (DGLGraph) – 图数据。

  • inputs (dict[str, Tensor] 或 pair of dict[str, Tensor]) – 输入节点特征。

  • mod_args (dict[str, tuple[any]], optional) – 子模块的额外位置参数。

  • mod_kwargs (dict[str, dict[str, any]], optional) – 子模块的额外关键字参数。

返回:

每种节点类型的输出表示。

返回类型:

dict[str, Tensor]