入门
进阶资料
API 参考
gather_mm()
注意事项
杂项
根据给定的索引收集数据并执行矩阵乘法。
令结果张量为 c,该操作执行以下计算
c
c[i] = a[i] @ b[idx_b[i]] ,其中 len(c) == len(idx_b)
a (张量) – 一个形状为 (N, D1) 的二维张量
(N, D1)
b (张量) – 一个形状为 (R, D1, D2) 的三维张量
(R, D1, D2)
idx_b (张量,可选) – 一个形状为 (N,) 的一维整型张量。
(N,)
形状为 (N, D2) 的输出密集矩阵
(N, D2)
张量