dgl.ops.segment_mm

dgl.ops.segment_mm(a, b, seglen_a)[source]

根据分段执行矩阵乘法。

假设 seglen_a == [10, 5, 0, 3],该操作符将执行四次矩阵乘法

a[0:10] @ b[0], a[10:15] @ b[1],
a[15:15] @ b[2], a[15:18] @ b[3]
参数:
  • a (Tensor) – 左操作数,形状为 (N, D1) 的二维张量

  • b (Tensor) – 右操作数,形状为 (R, D1, D2) 的三维张量

  • seglen_a (Tensor) – 形状为 (R,) 的整数张量。每个元素是输入 a 中对应分段的长度。所有元素的总和必须等于 N

返回:

输出的密集矩阵,形状为 (N, D2)

返回类型:

Tensor