AtomicConv

class dgl.nn.pytorch.conv.AtomicConv(interaction_cutoffs, rbf_kernel_means, rbf_kernel_scaling, features_to_use=None)[source]

基类: Module

来自 Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity 的原子卷积层

\(z_i\) 表示原子 \(i\) 的类型,用 \(r_{ij}\) 表示原子 \(i\)\(j\) 之间的距离。

距离转换

原子卷积层首先使用径向滤波器转换距离,然后执行池化操作。

对于索引为 \(k\) 的径向滤波器,它将边距离投影为

\[h_{ij}^{k} = \exp(-\gamma_{k}|r_{ij}-r_{k}|^2)\]

如果 \(r_{ij} < c_k\)

\[f_{ij}^{k} = 0.5 * \cos(\frac{\pi r_{ij}}{c_k} + 1),\]

否则,

\[f_{ij}^{k} = 0.\]

最后,

\[e_{ij}^{k} = h_{ij}^{k} * f_{ij}^{k}\]

聚合

对于每种类型 \(t\),每个原子从所有类型为 \(t\) 的邻居原子收集距离信息

\[p_{i, t}^{k} = \sum_{j\in N(i)} e_{ij}^{k} * 1(z_j == t)\]

然后连接所有 RBF 核和原子类型的结果。

参数:
  • interaction_cutoffs (形状为 (K) 的 float32 tensor) – 上述公式中的 \(c_k\)。大致可以视为可学习的截止距离,如果两个原子之间的距离小于该截止距离,则认为它们是连接的。K 表示径向滤波器的数量。

  • rbf_kernel_means (形状为 (K) 的 float32 tensor) – 上述公式中的 \(r_k\)。K 表示径向滤波器的数量。

  • rbf_kernel_scaling (形状为 (K) 的 float32 tensor) – 上述公式中的 \(\gamma_k\)。K 表示径向滤波器的数量。

  • features_to_use (None形状为 (T) 的 float tensor) – 在原始论文中,这些是要考虑的原子序数,代表原子的类型。T 表示原子序数的类型数量。默认为 None。

注意

  • 此卷积操作是为化学中的分子图设计的,但可能可以将其扩展到更通用的图。

  • 论文和作者实现中关于 \(e_{ij}^{k}\) 的定义似乎存在不一致之处。我们遵循作者的实现。在论文中,\(e_{ij}^{k}\) 被定义为 \(\exp(-\gamma_{k}|r_{ij}-r_{k}|^2 * f_{ij}^{k})\)

  • \(\gamma_{k}\)\(r_k\)\(c_k\) 都是可学习的。

示例

>>> import dgl
>>> import numpy as np
>>> import torch as th
>>> from dgl.nn import AtomicConv
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> feat = th.ones(6, 1)
>>> edist = th.ones(6, 1)
>>> interaction_cutoffs = th.ones(3).float() * 2
>>> rbf_kernel_means = th.ones(3).float()
>>> rbf_kernel_scaling = th.ones(3).float()
>>> conv = AtomicConv(interaction_cutoffs, rbf_kernel_means, rbf_kernel_scaling)
>>> res = conv(g, feat, edist)
>>> res
tensor([[0.5000, 0.5000, 0.5000],
            [0.5000, 0.5000, 0.5000],
            [0.5000, 0.5000, 0.5000],
            [1.0000, 1.0000, 1.0000],
            [0.5000, 0.5000, 0.5000],
            [0.0000, 0.0000, 0.0000]], grad_fn=<ViewBackward>)
forward(graph, feat, distances)[source]

描述

应用原子卷积层。

参数 graph:

执行消息传递所基于的拓扑结构。

类型 graph:

DGLGraph

参数 feat:

初始节点特征,在论文中为原子序数。\(V\) 表示节点数量。

类型 feat:

形状为 \((V, 1)\) 的 Float32 tensor

参数 distances:

边的端节点之间的距离。E 表示边的数量。

类型 distances:

形状为 \((E, 1)\) 的 Float32 tensor

返回:

更新后的节点表示。\(V\) 表示节点数量,\(K\) 表示径向滤波器的数量,\(T\) 表示原子序数的类型数量。

返回类型:

形状为 \((V, K * T)\) 的 Float32 tensor