AtomicConv
- class dgl.nn.pytorch.conv.AtomicConv(interaction_cutoffs, rbf_kernel_means, rbf_kernel_scaling, features_to_use=None)[source]
基类:
Module
来自 Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity 的原子卷积层
用 \(z_i\) 表示原子 \(i\) 的类型,用 \(r_{ij}\) 表示原子 \(i\) 和 \(j\) 之间的距离。
距离转换
原子卷积层首先使用径向滤波器转换距离,然后执行池化操作。
对于索引为 \(k\) 的径向滤波器,它将边距离投影为
\[h_{ij}^{k} = \exp(-\gamma_{k}|r_{ij}-r_{k}|^2)\]如果 \(r_{ij} < c_k\),
\[f_{ij}^{k} = 0.5 * \cos(\frac{\pi r_{ij}}{c_k} + 1),\]否则,
\[f_{ij}^{k} = 0.\]最后,
\[e_{ij}^{k} = h_{ij}^{k} * f_{ij}^{k}\]聚合
对于每种类型 \(t\),每个原子从所有类型为 \(t\) 的邻居原子收集距离信息
\[p_{i, t}^{k} = \sum_{j\in N(i)} e_{ij}^{k} * 1(z_j == t)\]然后连接所有 RBF 核和原子类型的结果。
- 参数:
interaction_cutoffs (形状为 (K) 的 float32 tensor) – 上述公式中的 \(c_k\)。大致可以视为可学习的截止距离,如果两个原子之间的距离小于该截止距离,则认为它们是连接的。K 表示径向滤波器的数量。
rbf_kernel_means (形状为 (K) 的 float32 tensor) – 上述公式中的 \(r_k\)。K 表示径向滤波器的数量。
rbf_kernel_scaling (形状为 (K) 的 float32 tensor) – 上述公式中的 \(\gamma_k\)。K 表示径向滤波器的数量。
features_to_use (None 或 形状为 (T) 的 float tensor) – 在原始论文中,这些是要考虑的原子序数,代表原子的类型。T 表示原子序数的类型数量。默认为 None。
注意
此卷积操作是为化学中的分子图设计的,但可能可以将其扩展到更通用的图。
论文和作者实现中关于 \(e_{ij}^{k}\) 的定义似乎存在不一致之处。我们遵循作者的实现。在论文中,\(e_{ij}^{k}\) 被定义为 \(\exp(-\gamma_{k}|r_{ij}-r_{k}|^2 * f_{ij}^{k})\)。
\(\gamma_{k}\)、\(r_k\) 和 \(c_k\) 都是可学习的。
示例
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import AtomicConv
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> feat = th.ones(6, 1) >>> edist = th.ones(6, 1) >>> interaction_cutoffs = th.ones(3).float() * 2 >>> rbf_kernel_means = th.ones(3).float() >>> rbf_kernel_scaling = th.ones(3).float() >>> conv = AtomicConv(interaction_cutoffs, rbf_kernel_means, rbf_kernel_scaling) >>> res = conv(g, feat, edist) >>> res tensor([[0.5000, 0.5000, 0.5000], [0.5000, 0.5000, 0.5000], [0.5000, 0.5000, 0.5000], [1.0000, 1.0000, 1.0000], [0.5000, 0.5000, 0.5000], [0.0000, 0.0000, 0.0000]], grad_fn=<ViewBackward>)
- forward(graph, feat, distances)[source]
描述
应用原子卷积层。
- 参数 graph:
执行消息传递所基于的拓扑结构。
- 类型 graph:
DGLGraph
- 参数 feat:
初始节点特征,在论文中为原子序数。\(V\) 表示节点数量。
- 类型 feat:
形状为 \((V, 1)\) 的 Float32 tensor
- 参数 distances:
边的端节点之间的距离。E 表示边的数量。
- 类型 distances:
形状为 \((E, 1)\) 的 Float32 tensor
- 返回:
更新后的节点表示。\(V\) 表示节点数量,\(K\) 表示径向滤波器的数量,\(T\) 表示原子序数的类型数量。
- 返回类型:
形状为 \((V, K * T)\) 的 Float32 tensor