SpatialEncoder3d

class dgl.nn.pytorch.gt.SpatialEncoder3d(num_kernels, num_heads=1, max_node_type=100)[源代码]

基类: Module

3D 空间编码器,如 One Transformer Can Understand Both 2D & 3D Molecular Data 中介绍的那样

此模块根据高斯基核函数对 3D 几何空间中节点对 \((i,j)\) 之间的成对关系进行编码

\(\psi _{(i,j)} ^k = \frac{1}{\sqrt{2\pi} \lvert \sigma^k \rvert} \exp{\left ( -\frac{1}{2} \left( \frac{\gamma_{(i,j)} \lvert \lvert r_i - r_j \rvert \rvert + \beta_{(i,j)} - \mu^k}{\lvert \sigma^k \rvert} \right) ^2 \right)},k=1,...,K,\)

其中 \(K\) 是高斯基核的数量。 \(r_i\) 是节点 \(i\) 的笛卡尔坐标。 \(\gamma_{(i,j)}, \beta_{(i,j)}\) 是由节点类型决定的可学习缩放因子和偏差。 \(\mu^k, \sigma^k\) 是高斯基核的可学习中心和标准差。

参数:
  • num_kernels (int) – 要应用的高斯基核的数量。每个高斯基核包含一个可学习的核中心和一个可学习的标准差。

  • num_heads (int, 可选) – 如果应用多头注意力机制,则为注意力头的数量。默认值:1。

  • max_node_type (int, 可选) – 最大节点类型数量。每种节点类型都有相应的可学习缩放因子和偏差。默认值:100。

示例

>>> import torch as th
>>> import dgl
>>> from dgl.nn import SpatialEncoder3d
>>> coordinate = th.rand(1, 4, 3)
>>> node_type = th.tensor([[1, 0, 2, 1]])
>>> spatial_encoder = SpatialEncoder3d(num_kernels=4,
...                                    num_heads=8,
...                                    max_node_type=3)
>>> out = spatial_encoder(coordinate, node_type=node_type)
>>> print(out.shape)
torch.Size([1, 4, 4, 8])
forward(coord, node_type=None)[源代码]
参数:
  • coord (torch.Tensor) – 节点的 3D 坐标,形状为 \((B, N, 3)\),其中 \(B\) 是批量大小,\(N\) 是最大节点数量。

  • node_type (torch.Tensor, 可选) –

    节点的节点类型 ID。默认值:None。

    • 如果指定,node_type 应为形状为 \((B, N,)\) 的张量。每对节点在高斯核中的缩放因子由其节点类型决定。

    • 否则,node_type 将默认为相同形状的零张量。

返回值:

返回注意力偏差作为 3D 空间编码,形状为 \((B, N, N, H)\),其中 \(H\)num_heads

返回类型:

torch.Tensor