CFConv
- class dgl.nn.pytorch.conv.CFConv(node_in_feats, edge_in_feats, hidden_feats, out_feats)[源码]
基类:
Module
来自SchNet: 用于模拟量子相互作用的连续滤波器卷积神经网络 的 CFConv
它在消息传递中结合节点和边特征,并更新节点表示。
其中
表示逐元素乘法,对于- 参数:
示例
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import CFConv >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> nfeat = th.ones(6, 10) >>> efeat = th.ones(6, 5) >>> conv = CFConv(10, 5, 3, 2) >>> res = conv(g, nfeat, efeat) >>> res tensor([[-0.1209, -0.2289], [-0.1209, -0.2289], [-0.1209, -0.2289], [-0.1135, -0.2338], [-0.1209, -0.2289], [-0.1283, -0.2240]], grad_fn=<SubBackward0>)
- forward(g, node_feats, edge_feats)[源码]
描述
执行消息传递并更新节点表示。
- 参数 g:
图。
- 类型 g:
DGLGraph
- 参数 node_feats:
输入节点特征。如果给定一个 torch.Tensor,它表示输入节点特征,形状为
,其中 是输入特征大小, 是节点数量。如果是二分图,则给定一对 torch.Tensor,这对张量必须分别包含形状为 和 的张量,分别对应源节点和目标节点。- 类型 node_feats:
torch.Tensor 或 torch.Tensor 对
- 参数 edge_feats:
输入边特征,形状为
,其中 是边的数量。- 类型 edge_feats:
torch.Tensor
- 返回:
输出节点特征,形状为
,其中 是目标节点的数量。- 返回类型:
torch.Tensor