dgl.sparse.bsddmm
- dgl.sparse.bsddmm(A: SparseMatrix, X1: Tensor, X2: Tensor) SparseMatrix [source]
批量采样-稠密-稠密矩阵乘法 (SDDMM)。
sddmm
将两个稠密矩阵X1
和X2
相乘,然后将结果与稀疏矩阵A
在非零位置进行逐元素乘法。数学上,
sddmm
的公式表示为\[out = (X1 @ X2) * A\]输入稠密矩阵的批量维度是最后一个维度。特别地,如果稀疏矩阵具有标量非零值,则将在 bsddmm 中进行广播。
- 参数:
A (SparseMatrix) – 形状为
(L, N)
的稀疏矩阵,其值为标量或长度为K
的向量。X1 (Tensor) – 形状为
(L, M, K)
的稠密矩阵。X2 (Tensor) – 形状为
(M, N, K)
的稠密矩阵。
- 返回值:
形状为
(L, N)
的稀疏矩阵,其值为长度为K
的向量。- 返回类型:
示例
>>> indices = torch.tensor([[1, 1, 2], [2, 3, 3]]) >>> val = torch.arange(1, 4).float() >>> A = dglsp.spmatrix(indices, val, (3, 4)) >>> X1 = torch.arange(0, 3 * 5 * 2).view(3, 5, 2).float() >>> X2 = torch.arange(0, 5 * 4 * 2).view(5, 4, 2).float() >>> dglsp.bsddmm(A, X1, X2) SparseMatrix(indices=tensor([[1, 1, 2], [2, 3, 3]]), values=tensor([[1560., 1735.], [3400., 3770.], [8400., 9105.]]), shape=(3, 4), nnz=3, val_size=(2,))