dgl.sparse.bsddmm

dgl.sparse.bsddmm(A: SparseMatrix, X1: Tensor, X2: Tensor) SparseMatrix[source]

批量采样-稠密-稠密矩阵乘法 (SDDMM)。

sddmm 将两个稠密矩阵 X1X2 相乘,然后将结果与稀疏矩阵 A 在非零位置进行逐元素乘法。

数学上,sddmm 的公式表示为

\[out = (X1 @ X2) * A\]

输入稠密矩阵的批量维度是最后一个维度。特别地,如果稀疏矩阵具有标量非零值,则将在 bsddmm 中进行广播。

参数:
  • A (SparseMatrix) – 形状为 (L, N) 的稀疏矩阵,其值为标量或长度为 K 的向量。

  • X1 (Tensor) – 形状为 (L, M, K) 的稠密矩阵。

  • X2 (Tensor) – 形状为 (M, N, K) 的稠密矩阵。

返回值:

形状为 (L, N) 的稀疏矩阵,其值为长度为 K 的向量。

返回类型:

SparseMatrix

示例

>>> indices = torch.tensor([[1, 1, 2], [2, 3, 3]])
>>> val = torch.arange(1, 4).float()
>>> A = dglsp.spmatrix(indices, val, (3, 4))
>>> X1 = torch.arange(0, 3 * 5 * 2).view(3, 5, 2).float()
>>> X2 = torch.arange(0, 5 * 4 * 2).view(5, 4, 2).float()
>>> dglsp.bsddmm(A, X1, X2)
SparseMatrix(indices=tensor([[1, 1, 2],
                             [2, 3, 3]]),
             values=tensor([[1560., 1735.],
                            [3400., 3770.],
                            [8400., 9105.]]),
             shape=(3, 4), nnz=3, val_size=(2,))