CFConv

class dgl.nn.pytorch.conv.CFConv(node_in_feats, edge_in_feats, hidden_feats, out_feats)[源码]

基类: Module

来自SchNet: 用于模拟量子相互作用的连续滤波器卷积神经网络 的 CFConv

它在消息传递中结合节点和边特征,并更新节点表示。

hi(l+1)=jN(i)hjlW(l)eij

其中表示逐元素乘法,对于SPP

SSP(x)=1βlog(1+exp(βx))log(shift)
参数:
  • node_in_feats (int) – 输入节点特征hj(l)的大小。

  • edge_in_feats (int) – 输入边特征eij的大小。

  • hidden_feats (int) – 隐藏表示的大小。

  • out_feats (int) – 输出表示hj(l+1)的大小。

示例

>>> import dgl
>>> import numpy as np
>>> import torch as th
>>> from dgl.nn import CFConv
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> nfeat = th.ones(6, 10)
>>> efeat = th.ones(6, 5)
>>> conv = CFConv(10, 5, 3, 2)
>>> res = conv(g, nfeat, efeat)
>>> res
tensor([[-0.1209, -0.2289],
        [-0.1209, -0.2289],
        [-0.1209, -0.2289],
        [-0.1135, -0.2338],
        [-0.1209, -0.2289],
        [-0.1283, -0.2240]], grad_fn=<SubBackward0>)
forward(g, node_feats, edge_feats)[源码]

描述

执行消息传递并更新节点表示。

参数 g:

图。

类型 g:

DGLGraph

参数 node_feats:

输入节点特征。如果给定一个 torch.Tensor,它表示输入节点特征,形状为 (N,Din),其中Din是输入特征大小,N是节点数量。如果是二分图,则给定一对 torch.Tensor,这对张量必须分别包含形状为(Nsrc,Dinsrc)(Ndst,Dindst)的张量,分别对应源节点和目标节点。

类型 node_feats:

torch.Tensor 或 torch.Tensor 对

参数 edge_feats:

输入边特征,形状为(E,edgeinfeats),其中E是边的数量。

类型 edge_feats:

torch.Tensor

返回:

输出节点特征,形状为(Nout,outfeats),其中Nout是目标节点的数量。

返回类型:

torch.Tensor