dgl.ops.gather_mm

dgl.ops.gather_mm(a, b, *, idx_b)[源码]

根据给定的索引收集数据并执行矩阵乘法。

令结果张量为 c,该操作执行以下计算

c[i] = a[i] @ b[idx_b[i]] ,其中 len(c) == len(idx_b)

参数:
  • a (张量) – 一个形状为 (N, D1) 的二维张量

  • b (张量) – 一个形状为 (R, D1, D2) 的三维张量

  • idx_b (张量可选) – 一个形状为 (N,) 的一维整型张量。

返回值:

形状为 (N, D2) 的输出密集矩阵

返回类型:

张量