快速入门

本教程快速介绍了 dgl.sparse 包提供的类和运算符。

Open In Colab GitHub

[1]:
# Install the required packages.

import os
# Uncomment following commands to download Pytorch and DGL
# !pip install torch==2.0.0+cpu torchvision==0.15.1+cpu torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cpu > /dev/null
# !pip install  dgl==1.1.0 -f https://data.dgl.ai/wheels/repo.html > /dev/null
import torch
os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"


try:
    import dgl.sparse as dglsp
    installed = True
except ImportError:
    installed = False
print("DGL installed!" if installed else "DGL not found!")
DGL installed!

稀疏矩阵

DGL 稀疏包的核心抽象是 SparseMatrix 类。与其他稀疏矩阵库(例如 scipy.sparsetorch.sparse)相比,DGL 的 SparseMatrix 专为结构化数据上的深度学习工作负载(例如图神经网络)而优化,具有以下特点

  • 自动稀疏格式。 无需费心选择不同的稀疏格式。只有一个 SparseMatrix,它将为要执行的操作选择最佳格式。

  • 非零元素可以是标量或向量。 易于通过向量表示对关系(例如边)进行建模。

  • 完全兼容 PyTorch。 该软件包基于 PyTorch 构建,并与 PyTorch 生态系统中的其他工具原生兼容。

创建 DGL 稀疏矩阵

创建稀疏矩阵最简单的方法是使用 spmatrix API,提供非零元素的索引。索引存储在一个形状为 (2, nnz) 的张量中,其中第 i 个非零元素存储在位置 (indices[0][i], indices[1][i])。下面的代码创建了一个 3x3 的稀疏矩阵。

[2]:
import torch
import dgl.sparse as dglsp

i = torch.tensor([[1, 1, 2],
                  [0, 2, 0]])
A = dglsp.spmatrix(i)  # 1.0 is default value for nnz elements.

print(A)
print("")
print("In dense format:")
print(A.to_dense())
SparseMatrix(indices=tensor([[1, 1, 2],
                             [0, 2, 0]]),
             values=tensor([1., 1., 1.]),
             shape=(3, 3), nnz=3)

In dense format:
tensor([[0., 0., 0.],
        [1., 0., 1.],
        [1., 0., 0.]])

如果未指定,形状将从索引中自动推断,但您也可以显式指定。

[3]:
i = torch.tensor([[0, 0, 1],
                  [0, 2, 0]])

A1 = dglsp.spmatrix(i)
print(f"Implicit Shape: {A1.shape}")
print(A1.to_dense())
print("")

A2 = dglsp.spmatrix(i, shape=(3, 3))
print(f"Explicit Shape: {A2.shape}")
print(A2.to_dense())
Implicit Shape: (2, 3)
tensor([[1., 0., 1.],
        [1., 0., 0.]])

Explicit Shape: (3, 3)
tensor([[1., 0., 1.],
        [1., 0., 0.],
        [0., 0., 0.]])

稀疏矩阵中的非零元素既可以设置为标量值,也可以设置为向量值。

[4]:
i = torch.tensor([[1, 1, 2],
                  [0, 2, 0]])
# The length of the value should match the nnz elements represented by the
# sparse matrix format.
scalar_val = torch.tensor([1., 2., 3.])
vector_val = torch.tensor([[1., 1.], [2., 2.], [3., 3.]])

print("-----Scalar Values-----")
A = dglsp.spmatrix(i, scalar_val)
print(A)
print("")
print("In dense format:")
print(A.to_dense())
print("")

print("-----Vector Values-----")
A = dglsp.spmatrix(i, vector_val)
print(A)
print("")
print("In dense format:")
print(A.to_dense())
-----Scalar Values-----
SparseMatrix(indices=tensor([[1, 1, 2],
                             [0, 2, 0]]),
             values=tensor([1., 2., 3.]),
             shape=(3, 3), nnz=3)

In dense format:
tensor([[0., 0., 0.],
        [1., 0., 2.],
        [3., 0., 0.]])

-----Vector Values-----
SparseMatrix(indices=tensor([[1, 1, 2],
                             [0, 2, 0]]),
             values=tensor([[1., 1.],
                            [2., 2.],
                            [3., 3.]]),
             shape=(3, 3), nnz=3, val_size=(2,))

In dense format:
tensor([[[0., 0.],
         [0., 0.],
         [0., 0.]],

        [[1., 1.],
         [0., 0.],
         [2., 2.]],

        [[3., 3.],
         [0., 0.],
         [0., 0.]]])

重复索引

[5]:
i = torch.tensor([[0, 0, 0, 1],
                  [0, 2, 2, 0]])
val = torch.tensor([1., 2., 3., 4])
A = dglsp.spmatrix(i, val)
print(A)
print(f"Whether A contains duplicate indices: {A.has_duplicate()}")
print("")

B = A.coalesce()
print(B)
print(f"Whether B contains duplicate indices: {B.has_duplicate()}")
SparseMatrix(indices=tensor([[0, 0, 0, 1],
                             [0, 2, 2, 0]]),
             values=tensor([1., 2., 3., 4.]),
             shape=(2, 3), nnz=4)
Whether A contains duplicate indices: True

SparseMatrix(indices=tensor([[0, 0, 1],
                             [0, 2, 0]]),
             values=tensor([1., 5., 4.]),
             shape=(2, 3), nnz=3)
Whether B contains duplicate indices: False

val_like

您可以通过保留给定稀疏矩阵的非零索引,但使用不同的非零值来创建一个新的稀疏矩阵。

[6]:
i = torch.tensor([[1, 1, 2],
                  [0, 2, 0]])
val = torch.tensor([1., 2., 3.])
A = dglsp.spmatrix(i, val)

new_val = torch.tensor([4., 5., 6.])
B = dglsp.val_like(A, new_val)
print(B)
SparseMatrix(indices=tensor([[1, 1, 2],
                             [0, 2, 0]]),
             values=tensor([4., 5., 6.]),
             shape=(3, 3), nnz=3)

从各种稀疏格式创建稀疏矩阵

  • from_coo(): 从 COO 格式创建稀疏矩阵。

  • from_csr(): 从 CSR 格式创建稀疏矩阵。

  • from_csc(): 从 CSC 格式创建稀疏矩阵。

[7]:
row = torch.tensor([0, 1, 2, 2, 2])
col = torch.tensor([1, 2, 0, 1, 2])

print("-----Create from COO format-----")
A = dglsp.from_coo(row, col)
print(A)
print("")
print("In dense format:")
print(A.to_dense())
print("")

indptr = torch.tensor([0, 1, 2, 5])
indices = torch.tensor([1, 2, 0, 1, 2])

print("-----Create from CSR format-----")
A = dglsp.from_csr(indptr, indices)
print(A)
print("")
print("In dense format:")
print(A.to_dense())
print("")

print("-----Create from CSC format-----")
B = dglsp.from_csc(indptr, indices)
print(B)
print("")
print("In dense format:")
print(B.to_dense())
-----Create from COO format-----
SparseMatrix(indices=tensor([[0, 1, 2, 2, 2],
                             [1, 2, 0, 1, 2]]),
             values=tensor([1., 1., 1., 1., 1.]),
             shape=(3, 3), nnz=5)

In dense format:
tensor([[0., 1., 0.],
        [0., 0., 1.],
        [1., 1., 1.]])

-----Create from CSR format-----
SparseMatrix(indices=tensor([[0, 1, 2, 2, 2],
                             [1, 2, 0, 1, 2]]),
             values=tensor([1., 1., 1., 1., 1.]),
             shape=(3, 3), nnz=5)

In dense format:
tensor([[0., 1., 0.],
        [0., 0., 1.],
        [1., 1., 1.]])

-----Create from CSC format-----
SparseMatrix(indices=tensor([[1, 2, 0, 1, 2],
                             [0, 1, 2, 2, 2]]),
             values=tensor([1., 1., 1., 1., 1.]),
             shape=(3, 3), nnz=5)

In dense format:
tensor([[0., 0., 1.],
        [1., 0., 1.],
        [0., 1., 1.]])

DGL 稀疏矩阵的属性和方法

[8]:
i = torch.tensor([[0, 1, 1, 2],
                  [1, 0, 2, 0]])
val = torch.tensor([1., 2., 3., 4.])
A = dglsp.spmatrix(i, val)

print(f"Shape of sparse matrix: {A.shape}")
print(f"The number of nonzero elements of sparse matrix: {A.nnz}")
print(f"Datatype of sparse matrix: {A.dtype}")
print(f"Device sparse matrix is stored on: {A.device}")
print(f"Get the values of the nonzero elements: {A.val}")
print(f"Get the row indices of the nonzero elements: {A.row}")
print(f"Get the column indices of the nonzero elements: {A.col}")
print(f"Get the coordinate (COO) representation: {A.coo()}")
print(f"Get the compressed sparse row (CSR) representation: {A.csr()}")
print(f"Get the compressed sparse column (CSC) representation: {A.csc()}")
Shape of sparse matrix: (3, 3)
The number of nonzero elements of sparse matrix: 4
Datatype of sparse matrix: torch.float32
Device sparse matrix is stored on: cpu
Get the values of the nonzero elements: tensor([1., 2., 3., 4.])
Get the row indices of the nonzero elements: tensor([0, 1, 1, 2])
Get the column indices of the nonzero elements: tensor([1, 0, 2, 0])
Get the coordinate (COO) representation: (tensor([0, 1, 1, 2]), tensor([1, 0, 2, 0]))
Get the compressed sparse row (CSR) representation: (tensor([0, 1, 3, 4]), tensor([1, 0, 2, 0]), tensor([0, 1, 2, 3]))
Get the compressed sparse column (CSC) representation: (tensor([0, 2, 3, 4]), tensor([1, 2, 0, 1]), tensor([1, 3, 0, 2]))

数据类型和/或设备转换

[9]:
i = torch.tensor([[0, 1, 1, 2],
                  [1, 0, 2, 0]])
val = torch.tensor([1., 2., 3., 4.])
A = dglsp.spmatrix(i, val)

B = A.to(device='cpu', dtype=torch.int32)
print(f"Device sparse matrix is stored on: {B.device}")
print(f"Datatype of sparse matrix: {B.dtype}")
Device sparse matrix is stored on: cpu
Datatype of sparse matrix: torch.int32

与 PyTorch 类似,我们也提供了各种细粒度的 API (文档) 用于数据类型和/或设备转换。

对角矩阵

对角矩阵是一种特殊的稀疏矩阵,其中主对角线以外的所有元素都为零。

初始化 DGL 对角稀疏矩阵

可以通过 dglsp.diag() 初始化 DGL 对角稀疏矩阵。

单位矩阵是一种特殊的对角稀疏矩阵,其中对角线上的所有值都为 1.0。使用 dglsp.identity() 初始化对角稀疏矩阵。

[10]:
val = torch.tensor([1., 2., 3., 4.])
D = dglsp.diag(val)
print(D)

I = dglsp.identity(shape=(3, 3))
print(I)
SparseMatrix(indices=tensor([[0, 1, 2, 3],
                             [0, 1, 2, 3]]),
             values=tensor([1., 2., 3., 4.]),
             shape=(4, 4), nnz=4)
SparseMatrix(indices=tensor([[0, 1, 2],
                             [0, 1, 2]]),
             values=tensor([1., 1., 1.]),
             shape=(3, 3), nnz=3)

稀疏矩阵上的操作

  • 元素级操作

    • A + B

    • A - B

    • A * B

    • A / B

    • A ** scalar

  • 广播操作

    • sp_<op>_v()

  • 归约操作

    • reduce()

    • sum()

    • smax()

    • smin()

    • smean()

  • 矩阵变换

    • SparseMatrix.transpose()SparseMatrix.T

    • SparseMatrix.neg()

    • SparseMatrix.inv()

  • 矩阵乘法

    • matmul()

    • sddmm()

在本教程中,我们使用密集格式打印稀疏矩阵,因为它更直观易读。

元素级操作

add(A, B),等价于 A + B

两个稀疏矩阵的元素级加法,返回一个稀疏矩阵。

[11]:
i = torch.tensor([[1, 1, 2],
                  [0, 2, 0]])
val = torch.tensor([1., 2., 3.])
A1 = dglsp.spmatrix(i, val, shape=(3, 3))
print("A1:")
print(A1.to_dense())

i = torch.tensor([[0, 1, 2],
                  [0, 2, 1]])
val = torch.tensor([4., 5., 6.])
A2 = dglsp.spmatrix(i, val, shape=(3, 3))
print("A2:")
print(A2.to_dense())

val = torch.tensor([-1., -2., -3.])
D1 = dglsp.diag(val)
print("D1:")
print(D1.to_dense())

val = torch.tensor([-4., -5., -6.])
D2 = dglsp.diag(val)
print("D2:")
print(D2.to_dense())

print("A1 + A2:")
print((A1 + A2).to_dense())

print("A1 + D1:")
print((A1 + D1).to_dense())

print("D1 + D2:")
print((D1 + D2).to_dense())
A1:
tensor([[0., 0., 0.],
        [1., 0., 2.],
        [3., 0., 0.]])
A2:
tensor([[4., 0., 0.],
        [0., 0., 5.],
        [0., 6., 0.]])
D1:
tensor([[-1.,  0.,  0.],
        [ 0., -2.,  0.],
        [ 0.,  0., -3.]])
D2:
tensor([[-4.,  0.,  0.],
        [ 0., -5.,  0.],
        [ 0.,  0., -6.]])
A1 + A2:
tensor([[4., 0., 0.],
        [1., 0., 7.],
        [3., 6., 0.]])
A1 + D1:
tensor([[-1.,  0.,  0.],
        [ 1., -2.,  2.],
        [ 3.,  0., -3.]])
D1 + D2:
tensor([[-5.,  0.,  0.],
        [ 0., -7.,  0.],
        [ 0.,  0., -9.]])

sub(A, B),等价于 A - B

两个稀疏矩阵的元素级减法,返回一个稀疏矩阵。

[12]:
i = torch.tensor([[1, 1, 2],
                  [0, 2, 0]])
val = torch.tensor([1., 2., 3.])
A1 = dglsp.spmatrix(i, val, shape=(3, 3))
print("A1:")
print(A1.to_dense())

i = torch.tensor([[0, 1, 2],
                  [0, 2, 1]])
val = torch.tensor([4., 5., 6.])
A2 = dglsp.spmatrix(i, val, shape=(3, 3))
print("A2:")
print(A2.to_dense())

val = torch.tensor([-1., -2., -3.])
D1 = dglsp.diag(val)
print("D1:")
print(D1.to_dense())

val = torch.tensor([-4., -5., -6.])
D2 = dglsp.diag(val)
print("D2:")
print(D2.to_dense())

print("A1 - A2:")
print((A1 - A2).to_dense())

print("A1 - D1:")
print((A1 - D1).to_dense())

print("D1 - A1:")
print((D1 - A1).to_dense())

print("D1 - D2:")
print((D1 - D2).to_dense())
A1:
tensor([[0., 0., 0.],
        [1., 0., 2.],
        [3., 0., 0.]])
A2:
tensor([[4., 0., 0.],
        [0., 0., 5.],
        [0., 6., 0.]])
D1:
tensor([[-1.,  0.,  0.],
        [ 0., -2.,  0.],
        [ 0.,  0., -3.]])
D2:
tensor([[-4.,  0.,  0.],
        [ 0., -5.,  0.],
        [ 0.,  0., -6.]])
A1 - A2:
tensor([[-4.,  0.,  0.],
        [ 1.,  0., -3.],
        [ 3., -6.,  0.]])
A1 - D1:
tensor([[1., 0., 0.],
        [1., 2., 2.],
        [3., 0., 3.]])
D1 - A1:
tensor([[-1.,  0.,  0.],
        [-1., -2., -2.],
        [-3.,  0., -3.]])
D1 - D2:
tensor([[3., 0., 0.],
        [0., 3., 0.],
        [0., 0., 3.]])

mul(A, B),等价于 A * B

两个稀疏矩阵或稀疏矩阵与标量的元素级乘法,返回一个稀疏矩阵。

[13]:
i = torch.tensor([[1, 1, 2],
                  [0, 2, 0]])
val = torch.tensor([1., 2., 3.])
A1 = dglsp.spmatrix(i, val, shape=(3, 3))
print("A1:")
print(A1.to_dense())

i = torch.tensor([[0, 1, 2, 2],
                  [0, 2, 0, 1]])
val = torch.tensor([1., 2., 3., 4.])
A2 = dglsp.spmatrix(i, val, shape=(3, 3))

print("A2:")
print(A2.to_dense())

print("A1 * 3:")
print((A1 * 3).to_dense())
print("3 * A1:")
print((3 * A1).to_dense())

print("A1 * A2")
print((A1 * A2).to_dense())

val = torch.tensor([-1., -2., -3.])
D1 = dglsp.diag(val)
print("D1:")
print(D1.to_dense())

print("D1 * A2")
print((D1 * A2).to_dense())

val = torch.tensor([-4., -5., -6.])
D2 = dglsp.diag(val)
print("D2:")
print(D2.to_dense())

print("D1 * -2:")
print((D1 * -2).to_dense())
print("-2 * D1:")
print((-2 * D1).to_dense())

print("D1 * D2:")
print((D1 * D2).to_dense())
A1:
tensor([[0., 0., 0.],
        [1., 0., 2.],
        [3., 0., 0.]])
A2:
tensor([[1., 0., 0.],
        [0., 0., 2.],
        [3., 4., 0.]])
A1 * 3:
tensor([[0., 0., 0.],
        [3., 0., 6.],
        [9., 0., 0.]])
3 * A1:
tensor([[0., 0., 0.],
        [3., 0., 6.],
        [9., 0., 0.]])
A1 * A2
tensor([[0., 0., 0.],
        [0., 0., 4.],
        [9., 0., 0.]])
D1:
tensor([[-1.,  0.,  0.],
        [ 0., -2.,  0.],
        [ 0.,  0., -3.]])
D1 * A2
tensor([[-1.,  0.,  0.],
        [ 0.,  0.,  0.],
        [ 0.,  0.,  0.]])
D2:
tensor([[-4.,  0.,  0.],
        [ 0., -5.,  0.],
        [ 0.,  0., -6.]])
D1 * -2:
tensor([[2., 0., 0.],
        [0., 4., 0.],
        [0., 0., 6.]])
-2 * D1:
tensor([[2., 0., 0.],
        [0., 4., 0.],
        [0., 0., 6.]])
D1 * D2:
tensor([[ 4.,  0.,  0.],
        [ 0., 10.,  0.],
        [ 0.,  0., 18.]])

div(A, B),等价于 A / B

两个稀疏矩阵或稀疏矩阵与标量的元素级乘法,返回一个稀疏矩阵。如果 AB 都是稀疏矩阵,则它们必须具有相同的稀疏度。并且返回的矩阵与 A 具有相同的非零条目顺序。

[14]:
i = torch.tensor([[1, 1, 2],
                  [0, 2, 0]])
val = torch.tensor([1., 2., 3.])
A1 = dglsp.spmatrix(i, val, shape=(3, 3))
print("A1:")
print(A1.to_dense())

i = torch.tensor([[1, 2, 1],
                  [0, 0, 2]])
val = torch.tensor([1., 3., 2.])
A2 = dglsp.spmatrix(i, val, shape=(3, 3))

print("A1 / 2:")
print((A1 / 2).to_dense())

print("A1 / A2")
print((A1 / A2).to_dense())

val = torch.tensor([-1., -2., -3.])
D1 = dglsp.diag(val)
print("D1:")
print(D1.to_dense())

val = torch.tensor([-4., -5., -6.])
D2 = dglsp.diag(val)
print("D2:")
print(D2.to_dense())

print("D1 / D2:")
print((D1 / D2).to_dense())

print("D1 / 2:")
print((D1 / 2).to_dense())
A1:
tensor([[0., 0., 0.],
        [1., 0., 2.],
        [3., 0., 0.]])
A1 / 2:
tensor([[0.0000, 0.0000, 0.0000],
        [0.5000, 0.0000, 1.0000],
        [1.5000, 0.0000, 0.0000]])
A1 / A2
tensor([[0., 0., 0.],
        [1., 0., 1.],
        [1., 0., 0.]])
D1:
tensor([[-1.,  0.,  0.],
        [ 0., -2.,  0.],
        [ 0.,  0., -3.]])
D2:
tensor([[-4.,  0.,  0.],
        [ 0., -5.,  0.],
        [ 0.,  0., -6.]])
D1 / D2:
tensor([[0.2500, 0.0000, 0.0000],
        [0.0000, 0.4000, 0.0000],
        [0.0000, 0.0000, 0.5000]])
D1 / 2:
tensor([[-0.5000,  0.0000,  0.0000],
        [ 0.0000, -1.0000,  0.0000],
        [ 0.0000,  0.0000, -1.5000]])

power(A, B),等价于 A ** B

稀疏矩阵与标量的元素级幂运算,返回一个稀疏矩阵。

[15]:
i = torch.tensor([[1, 1, 2],
                  [0, 2, 0]])
val = torch.tensor([1., 2., 3.])
A = dglsp.spmatrix(i, val, shape=(3, 3))
print("A:")
print(A.to_dense())

print("A ** 3:")
print((A ** 3).to_dense())

val = torch.tensor([-1., -2., -3.])
D = dglsp.diag(val)
print("D:")
print(D.to_dense())

print("D1 ** 2:")
print((D1 ** 2).to_dense())
A:
tensor([[0., 0., 0.],
        [1., 0., 2.],
        [3., 0., 0.]])
A ** 3:
tensor([[ 0.,  0.,  0.],
        [ 1.,  0.,  8.],
        [27.,  0.,  0.]])
D:
tensor([[-1.,  0.,  0.],
        [ 0., -2.,  0.],
        [ 0.,  0., -3.]])
D1 ** 2:
tensor([[1., 0., 0.],
        [0., 4., 0.],
        [0., 0., 9.]])

广播操作

**sp_<op>_v(A, v)**

对稀疏矩阵和向量进行广播操作,返回一个稀疏矩阵。v 被广播到 A 的形状,然后运算符应用于 A 的非零值。<op> 可以是 add, sub, mul 和 div。

关于 v 的形状有两种情况

  1. v 是形状为 (1, A.shape[1])(A.shape[1]) 的向量。在这种情况下,vA 的行维度上进行广播。

  2. v 是形状为 (A.shape[0], 1) 的向量。在这种情况下,vA 的列维度上进行广播。

[16]:
i = torch.tensor([[1, 0, 2], [0, 3, 2]])
val = torch.tensor([10, 20, 30])
A = dglsp.spmatrix(i, val, shape=(3, 4))

v1 = torch.tensor([1, 2, 3, 4])
print("A:")
print(A.to_dense())

print("v1:")
print(v1)

print("sp_add_v(A, v1)")
print(dglsp.sp_add_v(A, v1).to_dense())

v2 = v1.reshape(1, -1)
print("v2:")
print(v2)

print("sp_add_v(A, v2)")
print(dglsp.sp_add_v(A, v2).to_dense())

v3 = torch.tensor([1, 2, 3]).reshape(-1, 1)
print("v3:")
print(v3)

print("sp_add_v(A, v3)")
print(dglsp.sp_add_v(A, v3).to_dense())
A:
tensor([[ 0,  0,  0, 20],
        [10,  0,  0,  0],
        [ 0,  0, 30,  0]])
v1:
tensor([1, 2, 3, 4])
sp_add_v(A, v1)
tensor([[ 0,  0,  0, 24],
        [11,  0,  0,  0],
        [ 0,  0, 33,  0]])
v2:
tensor([[1, 2, 3, 4]])
sp_add_v(A, v2)
tensor([[ 0,  0,  0, 24],
        [11,  0,  0,  0],
        [ 0,  0, 33,  0]])
v3:
tensor([[1],
        [2],
        [3]])
sp_add_v(A, v3)
tensor([[ 0,  0,  0, 21],
        [12,  0,  0,  0],
        [ 0,  0, 33,  0]])

归约操作

所有 DGL 稀疏归约操作仅考虑非零元素。为了将它们与考虑零元素的密集 PyTorch 归约操作区分开来,我们使用名称 smaxsminsmeans 代表 sparse)。

[17]:
i = torch.tensor([[0, 1, 1, 2],
                  [1, 0, 2, 0]])
val = torch.tensor([1., 2., 3., 4.])
A = dglsp.spmatrix(i, val)
print(A.T.to_dense())
print("")

# O1, O2 will have the same value.
O1 = A.reduce(0, 'sum')
O2 = A.sum(0)
print("Reduce with reducer:sum along dim = 0:")
print(O1)
print("")

# O3, O4 will have the same value.
O3 = A.reduce(0, 'smax')
O4 = A.smax(0)
print("Reduce with reducer:max along dim = 0:")
print(O3)
print("")

# O5, O6 will have the same value.
O5 = A.reduce(0, 'smin')
O6 = A.smin(0)
print("Reduce with reducer:min along dim = 0:")
print(O5)
print("")

# O7, O8 will have the same value.
O7 = A.reduce(0, 'smean')
O8 = A.smean(0)
print("Reduce with reducer:smean along dim = 0:")
print(O7)
print("")
tensor([[0., 2., 4.],
        [1., 0., 0.],
        [0., 3., 0.]])

Reduce with reducer:sum along dim = 0:
tensor([6., 1., 3.])

Reduce with reducer:max along dim = 0:
tensor([4., 1., 3.])

Reduce with reducer:min along dim = 0:
tensor([2., 1., 3.])

Reduce with reducer:smean along dim = 0:
tensor([3., 1., 3.])

矩阵变换

稀疏矩阵

[18]:
i = torch.tensor([[0, 1, 1, 2],
                  [1, 0, 2, 0]])
val = torch.tensor([1., 2., 3., 4.])
A = dglsp.spmatrix(i, val)
print(A.to_dense())
print("")

print("Get transpose of sparse matrix.")
print(A.T.to_dense())
# Alias
# A.transpose()
# A.t()
print("")

print("Get a sparse matrix with the negation of the original nonzero values.")
print(A.neg().to_dense())
print("")
tensor([[0., 1., 0.],
        [2., 0., 3.],
        [4., 0., 0.]])

Get transpose of sparse matrix.
tensor([[0., 2., 4.],
        [1., 0., 0.],
        [0., 3., 0.]])

Get a sparse matrix with the negation of the original nonzero values.
tensor([[ 0., -1.,  0.],
        [-2.,  0., -3.],
        [-4.,  0.,  0.]])

矩阵乘法

matmul(A, B),等价于 A @ B

稀疏矩阵和/或密集矩阵上的矩阵乘法。有两种情况如下。

SparseMatrix @ SparseMatrix -> SparseMatrix

对于一个 \(L \times M\) 稀疏矩阵 A 和一个 \(M \times N\) 稀疏矩阵 B,A @ B 的形状将是一个 \(L \times N\) 稀疏矩阵。

[19]:
i = torch.tensor([[1, 1, 2],
                  [0, 2, 0]])
val = torch.tensor([1., 2., 3.])
A1 = dglsp.spmatrix(i, val, shape=(3, 3))
print("A1:")
print(A1.to_dense())

i = torch.tensor([[0, 1, 2],
                  [0, 2, 1]])
val = torch.tensor([4., 5., 6.])
A2 = dglsp.spmatrix(i, val, shape=(3, 3))
print("A2:")
print(A2.to_dense())

val = torch.tensor([-1., -2., -3.])
D1 = dglsp.diag(val)
print("D1:")
print(D1.to_dense())

val = torch.tensor([-4., -5., -6.])
D2 = dglsp.diag(val)
print("D2:")
print(D2.to_dense())

print("A1 @ A2:")
print((A1 @ A2).to_dense())

print("A1 @ D1:")
print((A1 @ D1).to_dense())

print("D1 @ A1:")
print((D1 @ A1).to_dense())

print("D1 @ D2:")
print((D1 @ D2).to_dense())
A1:
tensor([[0., 0., 0.],
        [1., 0., 2.],
        [3., 0., 0.]])
A2:
tensor([[4., 0., 0.],
        [0., 0., 5.],
        [0., 6., 0.]])
D1:
tensor([[-1.,  0.,  0.],
        [ 0., -2.,  0.],
        [ 0.,  0., -3.]])
D2:
tensor([[-4.,  0.,  0.],
        [ 0., -5.,  0.],
        [ 0.,  0., -6.]])
A1 @ A2:
tensor([[ 0.,  0.,  0.],
        [ 4., 12.,  0.],
        [12.,  0.,  0.]])
A1 @ D1:
tensor([[ 0.,  0.,  0.],
        [-1.,  0., -6.],
        [-3.,  0.,  0.]])
D1 @ A1:
tensor([[ 0.,  0.,  0.],
        [-2.,  0., -4.],
        [-9.,  0.,  0.]])
D1 @ D2:
tensor([[ 4.,  0.,  0.],
        [ 0., 10.,  0.],
        [ 0.,  0., 18.]])

SparseMatrix @ Tensor -> Tensor

对于一个 \(L \times M\) 稀疏矩阵 A 和一个 \(M \times N\) 密集矩阵 B,A @ B 的形状将是一个 \(L \times N\) 密集矩阵。

[20]:
i = torch.tensor([[1, 1, 2],
                  [0, 2, 0]])
val = torch.tensor([1., 2., 3.])
A = dglsp.spmatrix(i, val, shape=(3, 3))
print("A:")
print(A.to_dense())

val = torch.tensor([-1., -2., -3.])
D = dglsp.diag(val)
print("D:")
print(D.to_dense())

X = torch.tensor([[11., 22.], [33., 44.], [55., 66.]])
print("X:")
print(X)

print("A @ X:")
print(A @ X)

print("D @ X:")
print(D @ X)
A:
tensor([[0., 0., 0.],
        [1., 0., 2.],
        [3., 0., 0.]])
D:
tensor([[-1.,  0.,  0.],
        [ 0., -2.,  0.],
        [ 0.,  0., -3.]])
X:
tensor([[11., 22.],
        [33., 44.],
        [55., 66.]])
A @ X:
tensor([[  0.,   0.],
        [121., 154.],
        [ 33.,  66.]])
D @ X:
tensor([[ -11.,  -22.],
        [ -66.,  -88.],
        [-165., -198.]])

此运算符还支持批量稀疏-密集矩阵乘法。稀疏矩阵 A 的形状应为 \(L \times M\),其中非零值为长度为 \(K\) 的向量。密集矩阵 B 的形状应为 \(M \times N \times K\)。输出是形状为 \(L \times N \times K\) 的密集矩阵。

[21]:
i = torch.tensor([[1, 1, 2],
                  [0, 2, 0]])
val = torch.tensor([[1., 1.], [2., 2.], [3., 3.]])
A = dglsp.spmatrix(i, val, shape=(3, 3))
print("A:")
print(A.to_dense())

X = torch.tensor([[[1., 1.], [1., 2.]],
                  [[1., 3.], [1., 4.]],
                  [[1., 5.], [1., 6.]]])
print("X:")
print(X)

print("A @ X:")
print(A @ X)
A:
tensor([[[0., 0.],
         [0., 0.],
         [0., 0.]],

        [[1., 1.],
         [0., 0.],
         [2., 2.]],

        [[3., 3.],
         [0., 0.],
         [0., 0.]]])
X:
tensor([[[1., 1.],
         [1., 2.]],

        [[1., 3.],
         [1., 4.]],

        [[1., 5.],
         [1., 6.]]])
A @ X:
tensor([[[ 0.,  0.],
         [ 0.,  0.]],

        [[ 3., 11.],
         [ 3., 14.]],

        [[ 3.,  3.],
         [ 3.,  6.]]])

采样-密集-密集矩阵乘法 (SDDMM)

sddmm 将两个密集矩阵 X1 和 X2 相乘,然后在非零位置将结果与稀疏矩阵 A 进行元素级乘法。这适用于带有标量值的稀疏矩阵。

\[out = (X_1 @ X_2) * A\]

对于一个 \(L \times N\) 稀疏矩阵 A、一个 \(L \times M\) 密集矩阵 X1 和一个 \(M \times N\) 密集矩阵 X2,sddmm(A, X1, X2) 将是一个 \(L \times N\) 稀疏矩阵。

[22]:
i = torch.tensor([[1, 1, 2],
                  [2, 3, 3]])
val = torch.tensor([1., 2., 3.])
A = dglsp.spmatrix(i, val, (3, 4))
print("A:")
print(A.to_dense())

X1 = torch.randn(3, 5)
X2 = torch.randn(5, 4)
print("X1:")
print(X1)
print("X2:")
print(X2)

O = dglsp.sddmm(A, X1, X2)
print("dglsp.sddmm(A, X1, X2):")
print(O.to_dense())
A:
tensor([[0., 0., 0., 0.],
        [0., 0., 1., 2.],
        [0., 0., 0., 3.]])
X1:
tensor([[-1.8730,  1.1163,  0.1023, -2.2753, -0.4902],
        [ 0.2650,  0.6934, -0.0514, -1.1388,  0.6427],
        [ 0.2405,  1.2207,  0.8124,  0.4997, -0.4565]])
X2:
tensor([[-1.6420, -0.9294, -0.4997, -0.0368],
        [-0.5273, -0.4092,  0.1657,  0.6031],
        [ 0.1476, -0.0357, -0.2013,  0.0733],
        [ 1.1635, -0.8145, -1.0337,  1.9323],
        [ 0.0064,  0.0051, -0.3688, -0.8393]])
dglsp.sddmm(A, X1, X2):
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.9330, -4.6706],
        [ 0.0000,  0.0000,  0.0000,  6.4070]])

此运算符还支持批量采样-密集-密集矩阵乘法。对于一个 \(L \times N\) 稀疏矩阵 A(具有长度为 \(𝐾\) 的非零向量值)、一个 \(L \times M \times K\) 密集矩阵 X1 和一个 \(M \times N \times K\) 密集矩阵 X2,sddmm(A, X1, X2) 将是一个 \(L \times N \times K\) 稀疏矩阵。

[23]:
i = torch.tensor([[1, 1, 2],
                  [2, 3, 3]])
val = torch.tensor([[1., 1.], [2., 2.], [3., 3.]])
A = dglsp.spmatrix(i, val, (3, 4))
print("A:")
print(A.to_dense())

X1 = torch.randn(3, 5, 2)
X2 = torch.randn(5, 4, 2)
print("X1:")
print(X1)
print("X2:")
print(X2)

O = dglsp.sddmm(A, X1, X2)
print("dglsp.sddmm(A, X1, X2):")
print(O.to_dense())
A:
tensor([[[0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.],
         [1., 1.],
         [2., 2.]],

        [[0., 0.],
         [0., 0.],
         [0., 0.],
         [3., 3.]]])
X1:
tensor([[[-0.2226,  0.0104],
         [-1.0669, -0.1250],
         [ 0.3939,  0.1996],
         [-0.3159, -0.6870],
         [ 0.4624,  0.1524]],

        [[-0.3108,  1.4346],
         [ 0.2352, -0.2955],
         [ 0.0060, -0.3373],
         [ 0.9009, -0.1628],
         [ 0.5921,  2.1810]],

        [[ 0.8294,  0.2365],
         [ 1.2023,  2.0901],
         [ 0.2455, -0.0645],
         [-1.1549, -0.0709],
         [-0.5096,  0.4146]]])
X2:
tensor([[[-1.0710, -0.1782],
         [ 1.1551, -1.3017],
         [-0.3828, -0.3858],
         [ 0.5985, -0.4638]],

        [[ 1.3747, -0.0801],
         [ 0.2521,  0.1669],
         [-0.0505, -0.3232],
         [-0.3047,  0.3976]],

        [[-0.0027,  1.2815],
         [-0.5634, -1.1472],
         [ 0.3738,  0.6490],
         [ 0.0763, -0.1001]],

        [[-0.2826,  1.9760],
         [ 1.3104, -0.6009],
         [ 0.5239,  0.0615],
         [-1.2930,  1.3311]],

        [[ 0.7598,  0.9944],
         [-0.3392,  0.5695],
         [-0.7028, -0.2905],
         [ 1.3781, -0.2783]]])
dglsp.sddmm(A, X1, X2):
tensor([[[ 0.0000,  0.0000],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000]],

        [[ 0.0000,  0.0000],
         [ 0.0000,  0.0000],
         [ 0.1652, -1.3205],
         [-1.2121, -3.1457]],

        [[ 0.0000,  0.0000],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000],
         [ 2.8190,  1.5542]]])

非线性激活函数

元素级函数

大多数激活函数是元素级的,可以进一步分为两类

稀疏保留函数,例如 sin()tanh()sigmoid()relu() 等。您可以直接将它们应用于稀疏矩阵的 val 张量,然后使用 val_like 重建具有相同稀疏度的新矩阵。

[24]:
i = torch.tensor([[0, 1, 1, 2],
                  [1, 0, 2, 0]])
val = torch.randn(4)
A = dglsp.spmatrix(i, val)
print(A.to_dense())

print("Apply tanh.")
A_new = dglsp.val_like(A, torch.tanh(A.val))
print(A_new.to_dense())
tensor([[ 0.0000, -0.7017,  0.0000],
        [ 0.3844,  0.0000,  0.1851],
        [-1.6436,  0.0000,  0.0000]])
Apply tanh.
tensor([[ 0.0000, -0.6054,  0.0000],
        [ 0.3665,  0.0000,  0.1830],
        [-0.9280,  0.0000,  0.0000]])

非稀疏保留函数,例如 exp()cos() 等。在应用这些函数之前,您需要先将稀疏矩阵转换为密集矩阵。

[25]:
i = torch.tensor([[0, 1, 1, 2],
                  [1, 0, 2, 0]])
val = torch.randn(4)
A = dglsp.spmatrix(i, val)
print(A.to_dense())

print("Apply exp.")
A_new = A.to_dense().exp()
print(A_new)
tensor([[ 0.0000, -1.0587,  0.0000],
        [ 0.7273,  0.0000, -1.8200],
        [-0.2745,  0.0000,  0.0000]])
Apply exp.
tensor([[1.0000, 0.3469, 1.0000],
        [2.0695, 1.0000, 0.1620],
        [0.7599, 1.0000, 1.0000]])

Softmax

对稀疏矩阵的非零条目应用行级 softmax。

[26]:
i = torch.tensor([[0, 1, 1, 2],
                  [1, 0, 2, 0]])
val = torch.tensor([1., 2., 3., 4.])
A = dglsp.spmatrix(i, val)

print(A.softmax())
print("In dense format:")
print(A.softmax().to_dense())
print("\n")
SparseMatrix(indices=tensor([[0, 1, 1, 2],
                             [1, 0, 2, 0]]),
             values=tensor([1.0000, 0.2689, 0.7311, 1.0000]),
             shape=(3, 3), nnz=4)
In dense format:
tensor([[0.0000, 1.0000, 0.0000],
        [0.2689, 0.0000, 0.7311],
        [1.0000, 0.0000, 0.0000]])


练习 #1

让我们测试一下您学到的内容。请随意使用 在 Colab 中打开

给定一个稀疏的对称邻接矩阵 \(A\),计算其对称归一化邻接矩阵

\[norm = \bar{D}^{-\frac{1}{2}}\bar{A}\bar{D}^{-\frac{1}{2}}\]

其中 \(\bar{A} = A + I\)\(I\) 是单位矩阵,\(\bar{D}\)\(\bar{A}\) 的对角节点度矩阵。

[27]:
i = torch.tensor([[0, 0, 1, 1, 2, 2, 3],
                  [1, 3, 2, 5, 3, 5, 4]])
asym_A = dglsp.spmatrix(i, shape=(6, 6))
# Step 1: create symmetrical adjacency matrix A from asym_A.
# A =

# Step 2: calculate A_hat from A.
# A_hat =

# Step 3: diagonal node degree matrix of A_hat
# D_hat =

# Step 4: calculate the norm from D_hat and A_hat.
# norm =

练习 #2

让我们实现一个简化版的图注意力网络 (GAT) 层。

GAT 层有两个输入:邻接矩阵 \(A\) 和节点输入特征 \(X\)。GAT 层的思想是用节点自身表示及其邻居表示的加权平均来更新每个节点的表示。具体来说,当计算节点 \(i\) 的输出时,GAT 层执行以下操作:1. 计算分数 \(S_{ij}\),表示从邻居 \(j\) 到节点 \(i\) 的注意力 logits。\(S_{ij}\) 是关于 \(i\)\(j\) 的输入特征 \(X_i\)\(X_j\) 的函数

\[S_{ij} = LeakyReLU(X_i^\top v_1 + X_j^\top v_2)\]

,其中 \(v_1\)\(v_2\) 是可训练向量。2. 计算 softmax 注意力 \(R_{ij} = \exp S_{ij} / \left( \sum_{j' \in \mathcal{N}_i} s_{ij'} \right)\),其中 \(\mathcal{N}_j\) 表示 \(j\) 的邻居。这意味着 \(R\)\(S\) 的行级 softmax 注意力。3. 计算加权平均 \(H_i = \sum_{j' : j' \in \mathcal{N}_i} R_{j'} X_{j'} W\),其中 \(W\) 是可训练矩阵。

以下代码定义了您需要的所有参数,但只完成了步骤 1。您能否实现步骤 2 和步骤 3?

[28]:
import torch.nn as nn
import torch.nn.functional as F

class SimplifiedGAT(nn.Module):
    def __init__(self, in_size, out_size):
        super().__init__()

        self.W = nn.Parameter(torch.randn(in_size, out_size))
        self.v1 = nn.Parameter(torch.randn(in_size))
        self.v2 = nn.Parameter(torch.randn(in_size))

    def forward(self, A, X):
        # A: A sparse matrix with size (N, N).  A[i, j] represent the edge from j to i.
        # X: A dense matrix with size (N, D)
        # Step 1: compute S[i, j]
        Xv1 = X @ self.v1
        Xv2 = X @ self.v2
        s = F.leaky_relu(Xv1[A.col] + Xv2[A.row])
        S = dglsp.val_like(A, s)

        # Step 2: compute R[i, j] which is the row-wise attention of $S$.
        # EXERCISE: replace the statement below.
        R = S

        # Step 3: compute H.
        # EXERCISE: replace the statement below.
        H = X

        return H
[29]:
# Test:
# Let's use the symmetric A created above.
X = torch.randn(6, 20)
module = SimplifiedGAT(20, 10)
Y = module(A, X)