HeteroGraphConv
- class dgl.nn.pytorch.HeteroGraphConv(mods, aggregate='sum')[source]
基类:
Module
一个用于在异构图上进行卷积计算的通用模块。
异构图卷积将子模块应用于其关联的关系图,从源节点读取特征并将其更新到目标节点。如果多个关系具有相同的目标节点类型,则其结果将通过指定的方法进行聚合。如果关系图没有边,则不会调用相应的模块。
伪代码
outputs = {nty : [] for nty in g.dsttypes} # Apply sub-modules on their associating relation graphs in parallel for relation in g.canonical_etypes: stype, etype, dtype = relation dstdata = relation_submodule(g[relation], ...) outputs[dtype].append(dstdata) # Aggregate the results for each destination node type rsts = {} for ntype, ntype_outputs in outputs.items(): if len(ntype_outputs) != 0: rsts[ntype] = aggregate(ntype_outputs) return rsts
示例
创建一个包含三种关系类型和节点类型的异构图。
>>> import dgl >>> g = dgl.heterograph({ ... ('user', 'follows', 'user') : edges1, ... ('user', 'plays', 'game') : edges2, ... ('store', 'sells', 'game') : edges3})
创建一个
HeteroGraphConv
,它对不同的关系应用不同的卷积模块。注意,'follows'
和'plays'
关系对应的模块不共享权重。>>> import dgl.nn.pytorch as dglnn >>> conv = dglnn.HeteroGraphConv({ ... 'follows' : dglnn.GraphConv(...), ... 'plays' : dglnn.GraphConv(...), ... 'sells' : dglnn.SAGEConv(...)}, ... aggregate='sum')
使用一些
'user'
特征调用 forward。这将计算'user'
和'game'
节点的最新特征。>>> import torch as th >>> h1 = {'user' : th.randn((g.num_nodes('user'), 5))} >>> h2 = conv(g, h1) >>> print(h2.keys()) dict_keys(['user', 'game'])
同时使用
'user'
和'store'
特征调用 forward。由于'plays'
和'sells'
关系都会更新'game'
特征,因此它们的结果会通过指定的方法(此处为求和)进行聚合。>>> f1 = {'user' : ..., 'store' : ...} >>> f2 = conv(g, f1) >>> print(f2.keys()) dict_keys(['user', 'game'])
使用一些
'store'
特征调用 forward。这只计算'game'
节点的最新特征。>>> g1 = {'store' : ...} >>> g2 = conv(g, g1) >>> print(g2.keys()) dict_keys(['game'])
允许使用一对输入调用 forward,并且每个子模块也将使用一对输入进行调用。
>>> x_src = {'user' : ..., 'store' : ...} >>> x_dst = {'user' : ..., 'game' : ...} >>> y_dst = conv(g, (x_src, x_dst)) >>> print(y_dst.keys()) dict_keys(['user', 'game'])
- 参数:
mods (dict[str, nn.Module]) – 与每种边类型关联的模块。每个模块的 forward 函数必须将一个 DGLGraph 对象作为第一个参数,其第二个参数可以是表示节点特征的张量对象,或表示源节点和目标节点特征的张量对象对。
aggregate (str, callable, optional) –
聚合不同关系生成的节点特征的方法。允许的字符串值为 ‘sum’, ‘max’, ‘min’, ‘mean’, ‘stack’。‘stack’ 聚合沿第二维执行,其顺序是确定的。用户还可以通过提供可调用实例来自定义聚合器。例如,求和聚合等效于以下代码
def my_agg_func(tensors, dsttype): # tensors: is a list of tensors to aggregate # dsttype: string name of the destination node type for which the # aggregation is performed stacked = torch.stack(tensors, dim=0) return torch.sum(stacked, dim=0)