dgl.readout_nodes
- dgl.readout_nodes(graph, feat, weight=None, *, op='sum', ntype=None)[source]
通过聚合节点特征
feat
生成图级别的表示。该函数通常用作批处理图上的 readout 函数,以生成图级别的表示。因此,结果张量的形状取决于输入图的批处理大小。给定一个批处理大小为 \(B\) 且特征大小为 \(D\) 的图,结果形状将是 \((B, D)\),其中每一行是每个图的聚合节点特征。
- 参数:
- 返回:
结果张量。
- 返回类型:
张量
示例
>>> import dgl >>> import torch as th
创建两个
DGLGraph
对象并初始化它们的节点特征。>>> g1 = dgl.graph(([0, 1], [1, 0])) # Graph 1 >>> g1.ndata['h'] = th.tensor([1., 2.]) >>> g2 = dgl.graph(([0, 1], [1, 2])) # Graph 2 >>> g2.ndata['h'] = th.tensor([1., 2., 3.])
在一个图上求和
>>> dgl.readout_nodes(g1, 'h') tensor([3.]) # 1 + 2
在一个批处理图上求和
>>> bg = dgl.batch([g1, g2]) >>> dgl.readout_nodes(bg, 'h') tensor([3., 6.]) # [1 + 2, 1 + 2 + 3]
加权求和
>>> bg.ndata['w'] = th.tensor([.1, .2, .1, .5, .2]) >>> dgl.readout_nodes(bg, 'h', 'w') tensor([.5, 1.7])
按最大值聚合
>>> dgl.readout_nodes(bg, 'h', op='max') tensor([2., 3.])