dgl.ops.segment_mm
- dgl.ops.segment_mm(a, b, seglen_a)[source]
根据分段执行矩阵乘法。
假设
seglen_a == [10, 5, 0, 3]
,该操作符将执行四次矩阵乘法a[0:10] @ b[0], a[10:15] @ b[1], a[15:15] @ b[2], a[15:18] @ b[3]
- 参数:
a (Tensor) – 左操作数,形状为
(N, D1)
的二维张量b (Tensor) – 右操作数,形状为
(R, D1, D2)
的三维张量seglen_a (Tensor) – 形状为
(R,)
的整数张量。每个元素是输入a
中对应分段的长度。所有元素的总和必须等于N
。
- 返回:
输出的密集矩阵,形状为
(N, D2)
- 返回类型:
Tensor