dgl.ops

用于图上消息传递的框架无关算子。

GSpMM 函数

广义稀疏矩阵稠密矩阵乘法函数。它将两个步骤融合成一个核。

  1. 通过对源节点和边特征进行加/减/乘/除运算,或将节点特征复制到边上,来计算消息。

  2. 通过求和/最大值/最小值/平均值将消息聚合成目标节点上的特征。

我们的实现支持 PyTorch/MXNet/Tensorflow 中 CPU/GPU 上的张量作为输入。所有算子都支持自动微分 (根据输出梯度计算输入梯度) 和广播 (如果操作数特征形状不匹配,我们会先将它们广播到相同形状,然后应用二元算子)。我们的广播语义遵循 NumPy,更多详情请参见 https://docs.scipy.org.cn/doc/numpy/user/basics.broadcasting.html

我们所说的融合是指消息不会在边上具体化,而是直接在目标节点上计算结果,从而节省内存成本。GSpMM 算子的空间复杂度为 \(O(|N|D)\),其中 \(|N|\) 是图中节点的数量,\(D\) 是特征大小 (如果你的特征是多维张量,则 \(D=\prod_{i=1}^{N}D_i\))。

以下是展示 GSpMM 工作原理的示例 (这里我们使用 PyTorch 作为后端,您可以在其他框架上通过类似用法享受同样的便利)

>>> import dgl
>>> import torch as th
>>> import dgl.ops as F
>>> g = dgl.graph(([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]))  # 3 nodes, 6 edges
>>> x = th.ones(3, 2, requires_grad=True)
>>> x
tensor([[1., 1.],
        [1., 1.],
        [1., 1.]], requires_grad=True)
>>> y = th.arange(1, 13).float().view(6, 2).requires_grad_()
tensor([[ 1.,  2.],
        [ 3.,  4.],
        [ 5.,  6.],
        [ 7.,  8.],
        [ 9., 10.],
        [11., 12.]], requires_grad=True)
>>> out_1 = F.u_mul_e_sum(g, x, y)
>>> out_1  # (10, 12) = ((1, 1) * (3, 4)) + ((1, 1) * (7, 8))
tensor([[ 1.,  2.],
        [10., 12.],
        [25., 28.]], grad_fn=<GSpMMBackward>)
>>> out_1.sum().backward()
>>> x.grad
tensor([[12., 15.],
        [18., 20.],
        [12., 13.]])
>>> y.grad
tensor([[1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])
>>> out_2 = F.copy_u_sum(g, x)
>>> out_2
tensor([[1., 1.],
        [2., 2.],
        [3., 3.]], grad_fn=<GSpMMBackward>)
>>> out_3 = F.u_add_e_max(g, x, y)
>>> out_3
tensor([[ 2.,  3.],
        [ 8.,  9.],
        [12., 13.]], grad_fn=<GSpMMBackward>)
>>> y1 = th.rand(6, 4, 2, requires_grad=True)  # test broadcast
>>> F.u_mul_e_sum(g, x, y1).shape  # (2,), (4, 2) -> (4, 2)
torch.Size([3, 4, 2])

对于所有算子,输入图可以是同构图或二部图。

gspmm(g, op, reduce_op, lhs_data, rhs_data)

广义稀疏矩阵乘法接口。

u_add_e_sum(g, x, y)

广义 SpMM 函数。

u_sub_e_sum(g, x, y)

广义 SpMM 函数。

u_mul_e_sum(g, x, y)

广义 SpMM 函数。

u_div_e_sum(g, x, y)

广义 SpMM 函数。

u_add_e_max(g, x, y)

广义 SpMM 函数。

u_sub_e_max(g, x, y)

广义 SpMM 函数。

u_mul_e_max(g, x, y)

广义 SpMM 函数。

u_div_e_max(g, x, y)

广义 SpMM 函数。

u_add_e_min(g, x, y)

广义 SpMM 函数。

u_sub_e_min(g, x, y)

广义 SpMM 函数。

u_mul_e_min(g, x, y)

广义 SpMM 函数。

u_div_e_min(g, x, y)

广义 SpMM 函数。

u_add_e_mean(g, x, y)

广义 SpMM 函数。

u_sub_e_mean(g, x, y)

广义 SpMM 函数。

u_mul_e_mean(g, x, y)

广义 SpMM 函数。

u_div_e_mean(g, x, y)

广义 SpMM 函数。

copy_u_sum(g, x)

广义 SpMM 函数。

copy_e_sum(g, x)

广义 SpMM 函数。

copy_u_max(g, x)

广义 SpMM 函数。

copy_e_max(g, x)

广义 SpMM 函数。

copy_u_min(g, x)

广义 SpMM 函数。

copy_e_min(g, x)

广义 SpMM 函数。

copy_u_mean(g, x)

广义 SpMM 函数。

copy_e_mean(g, x)

广义 SpMM 函数。

GSDDMM 函数

广义采样稠密-稠密矩阵乘法。它通过对源节点/目标节点或边上的特征进行加/减/乘/除/点积运算来计算边特征。

与 GSpMM 类似,我们的实现支持 PyTorch/MXNet/Tensorflow 中 CPU/GPU 上的张量作为输入。所有算子都支持自动微分和广播。

GSDDMM 的内存成本为 \(O(|E|D)\),其中 \(|E|\) 是图中边的数量,\(D\) 是特征大小。

注意,我们支持 dot 算子,它在语义上与对 mul 算子结果沿最后一个维度求和归约相同。然而,dot 更节省内存,因为它融合mul 和求和归约,这在最后一个维度特征大小较大 (例如 Transformer 类模型中的多头注意力) 的情况下至关重要。

以下是展示 GSDDMM 工作原理的示例

>>> import dgl
>>> import torch as th
>>> import dgl.ops as F
>>> g = dgl.graph(([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]))  # 3 nodes, 6 edges
>>> x = th.ones(3, 2, requires_grad=True)
>>> x
tensor([[1., 1.],
        [1., 1.],
        [1., 1.]], requires_grad=True)
>>> y = th.arange(1, 7).float().view(3, 2).requires_grad_()
>>> y
tensor([[1., 2.],
        [3., 4.],
        [5., 6.]], requires_grad=True)
>>> e = th.ones(6, 1, 2, requires_grad=True) * 2
tensor([[[2., 2.]],
        [[2., 2.]],
        [[2., 2.]],
        [[2., 2.]],
        [[2., 2.]],
        [[2., 2.]]], grad_fn=<MulBackward0>)
>>> out1 = F.u_div_v(g, x, y)
tensor([[1.0000, 0.5000],
        [0.3333, 0.2500],
        [0.2000, 0.1667],
        [0.3333, 0.2500],
        [0.2000, 0.1667],
        [0.2000, 0.1667]], grad_fn=<GSDDMMBackward>)
>>> out1.sum().backward()
>>> x.grad
tensor([[1.5333, 0.9167],
        [0.5333, 0.4167],
        [0.2000, 0.1667]])
>>> y.grad
tensor([[-1.0000, -0.2500],
        [-0.2222, -0.1250],
        [-0.1200, -0.0833]])
>>> out2 = F.e_sub_v(g, e, y)
>>> out2
tensor([[[ 1.,  0.]],
        [[-1., -2.]],
        [[-3., -4.]],
        [[-1., -2.]],
        [[-3., -4.]],
        [[-3., -4.]]], grad_fn=<GSDDMMBackward>)
>>> out3 = F.copy_v(g, y)
>>> out3
tensor([[1., 2.],
        [3., 4.],
        [5., 6.],
        [3., 4.],
        [5., 6.],
        [5., 6.]], grad_fn=<GSDDMMBackward>)
>>> out4 = F.u_dot_v(g, x, y)
>>> out4  # the last dimension was reduced to size 1.
tensor([[ 3.],
        [ 7.],
        [11.],
        [ 7.],
        [11.],
        [11.]], grad_fn=<GSDDMMBackward>)

gsddmm(g, op, lhs_data, rhs_data[, ...])

广义采样稠密-稠密矩阵乘法接口。

u_add_v(g, x, y)

广义 SDDMM 函数。

u_sub_v(g, x, y)

广义 SDDMM 函数。

u_mul_v(g, x, y)

广义 SDDMM 函数。

u_dot_v(g, x, y)

广义 SDDMM 函数。

u_div_v(g, x, y)

广义 SDDMM 函数。

u_add_e(g, x, y)

广义 SDDMM 函数。

u_sub_e(g, x, y)

广义 SDDMM 函数。

u_mul_e(g, x, y)

广义 SDDMM 函数。

u_dot_e(g, x, y)

广义 SDDMM 函数。

u_div_e(g, x, y)

广义 SDDMM 函数。

e_add_v(g, x, y)

广义 SDDMM 函数。

e_sub_v(g, x, y)

广义 SDDMM 函数。

e_mul_v(g, x, y)

广义 SDDMM 函数。

e_dot_v(g, x, y)

广义 SDDMM 函数。

e_div_v(g, x, y)

广义 SDDMM 函数。

v_add_u(g, x, y)

广义 SDDMM 函数。

v_sub_u(g, x, y)

广义 SDDMM 函数。

v_mul_u(g, x, y)

广义 SDDMM 函数。

v_dot_u(g, x, y)

广义 SDDMM 函数。

v_div_u(g, x, y)

广义 SDDMM 函数。

e_add_u(g, x, y)

广义 SDDMM 函数。

e_sub_u(g, x, y)

广义 SDDMM 函数。

e_mul_u(g, x, y)

广义 SDDMM 函数。

e_dot_u(g, x, y)

广义 SDDMM 函数。

e_div_u(g, x, y)

广义 SDDMM 函数。

v_add_e(g, x, y)

广义 SDDMM 函数。

v_sub_e(g, x, y)

广义 SDDMM 函数。

v_mul_e(g, x, y)

广义 SDDMM 函数。

v_dot_e(g, x, y)

广义 SDDMM 函数。

v_div_e(g, x, y)

广义 SDDMM 函数。

copy_u(g, x)

将源节点特征复制到边上的广义 SDDMM 函数。

copy_v(g, x)

将目标节点特征复制到边上的广义 SDDMM 函数。

与 GSpMM 类似,GSDDMM 算子同时支持同构图和二部图。

段归约模块

DGL 提供了按段沿第一个维度归约值张量的算子。

segment_reduce(seglen, value[, reducer])

段归约算子。

GatherMM 和 SegmentMM 模块

SegmentMM: DGL 提供了按段执行矩阵乘法的算子。

GatherMM: DGL 提供了根据给定索引聚集数据并执行矩阵乘法的算子。

gather_mm(a, b, *, idx_b)

根据给定索引聚集数据并执行矩阵乘法。

segment_mm(a, b, seglen_a)

按段执行矩阵乘法。

支持的数据类型

定义在 dgl.ops 中的算子支持浮点数据类型,即操作数必须是 half (float16) /float/double 张量。输入张量必须具有相同的数据类型 (如果一个输入张量是 float16 类型而另一个是 float32 类型,用户必须将其中一个转换为与另一个对齐)。

默认情况下,float16 数据类型支持是禁用的,因为它要求最低 GPU 计算能力为 sm_53 (Pascal, Volta, Turing 和 Ampere 架构)。

用户可以通过从源代码编译 DGL 来启用 float16 以进行混合精度训练 (详情请参见 混合精度训练 教程)。

与消息传递 API 的关系

带有内置消息/归约函数的 dgl.update_alldgl.apply_edges 调用会被分派到定义在 dgl.ops 中的算子的函数调用。

>>> import dgl
>>> import torch as th
>>> import dgl.ops as F
>>> import dgl.function as fn
>>> g = dgl.rand_graph(100, 1000)   # create a DGLGraph with 100 nodes and 1000 edges.
>>> x = th.rand(100, 20)            # node features.
>>> e = th.rand(1000, 20)
>>>
>>> # dgl.update_all + builtin functions
>>> g.srcdata['x'] = x              # srcdata is the same as ndata for graphs with one node type.
>>> g.edata['e'] = e
>>> g.update_all(fn.u_mul_e('x', 'e', 'm'), fn.sum('m', 'y'))
>>> y = g.dstdata['y']              # dstdata is the same as ndata for graphs with one node type.
>>>
>>> # use GSpMM operators defined in dgl.ops directly
>>> y = F.u_mul_e_sum(g, x, e)
>>>
>>> # dgl.apply_edges + builtin functions
>>> g.srcdata['x'] = x
>>> g.dstdata['y'] = y
>>> g.apply_edges(fn.u_dot_v('x', 'y', 'z'))
>>> z = g.edata['z']
>>>
>>> # use GSDDMM operators defined in dgl.ops directly
>>> z = F.u_dot_v(g, x, y)

用户可以决定是使用消息传递 API 还是 GSpMM/GSDDMM 算子,两者的效率相同。使用消息传递 API 编写的程序更具 DGL 风格,但在某些情况下,调用 GSpMM/GSDDMM 算子会更简洁。

注意,在 PyTorch 中,定义在 dgl.ops 中的所有算子都支持高阶梯度,消息传递 API 也支持,因为它们完全依赖于这些算子。