DGL 如何表示图?

在本教程结束时,您将能够

  • 从头开始在 DGL 中构建图。

  • 为图分配节点和边特征。

  • 查询 DGL 图的属性,例如节点度数和连接性。

  • 将一个 DGL 图转换为另一个图。

  • 加载和保存 DGL 图。

(预计时间:16 分钟)

DGL 图的构建

DGL 将有向图表示为 DGLGraph 对象。您可以通过指定图中的节点数以及源节点和目标节点的列表来构建图。图中的节点具有从 0 开始的连续 ID。

例如,以下代码构建了一个具有 5 个叶节点的有向星形图。中心节点的 ID 为 0。边从中心节点指向叶节点。

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import numpy as np
import torch

g = dgl.graph(([0, 0, 0, 0, 0], [1, 2, 3, 4, 5]), num_nodes=6)
# Equivalently, PyTorch LongTensors also work.
g = dgl.graph(
    (torch.LongTensor([0, 0, 0, 0, 0]), torch.LongTensor([1, 2, 3, 4, 5])),
    num_nodes=6,
)

# You can omit the number of nodes argument if you can tell the number of nodes from the edge list alone.
g = dgl.graph(([0, 0, 0, 0, 0], [1, 2, 3, 4, 5]))

图中的边具有从 0 开始的连续 ID,并且顺序与创建时源节点和目标节点的列表相同。

# Print the source and destination nodes of every edge.
print(g.edges())
(tensor([0, 0, 0, 0, 0]), tensor([1, 2, 3, 4, 5]))

注意

DGLGraph 对象始终是有向的,以最适合图神经网络的计算模式,其中从一个节点发送到另一个节点的消息在两个方向上通常是不同的。如果您想处理无向图,可以考虑将其视为双向图。请参阅图变换了解如何创建双向图的示例。

为图分配节点和边特征

许多图数据包含节点和边的属性。尽管节点和边属性的类型在现实世界中可以是任意的,但 DGLGraph 只接受存储在张量中(具有数值内容)的属性。因此,所有节点或边的属性必须具有相同的形状。在深度学习的背景下,这些属性通常被称为特征

您可以通过 ndataedata 接口分配和检索节点和边特征。

# Assign a 3-dimensional node feature vector for each node.
g.ndata["x"] = torch.randn(6, 3)
# Assign a 4-dimensional edge feature vector for each edge.
g.edata["a"] = torch.randn(5, 4)
# Assign a 5x4 node feature matrix for each node.  Node and edge features in DGL can be multi-dimensional.
g.ndata["y"] = torch.randn(6, 5, 4)

print(g.edata["a"])
tensor([[ 0.0240, -1.5106,  1.5920,  0.6187],
        [ 0.4512,  1.4332,  0.2833, -0.1837],
        [ 1.7433, -2.2757,  0.4255,  0.8959],
        [ 1.3513,  1.0305,  0.2473, -0.6149],
        [ 0.2174, -2.3516,  0.5551, -0.9183]])

注意

深度学习的巨大发展为我们提供了许多将各种类型的属性编码为数值特征的方法。以下是一些一般性建议

  • 对于类别属性(例如性别、职业),考虑将它们转换为整数或进行独热编码。

  • 对于变长字符串内容(例如新闻文章、引用),考虑应用语言模型。

  • 对于图像,考虑应用视觉模型,例如 CNN。

您可以在 PyTorch 深度学习教程 中找到大量关于如何将此类属性编码到张量中的资料。

查询图结构

DGLGraph 对象提供了多种方法来查询图结构。

print(g.num_nodes())
print(g.num_edges())
# Out degrees of the center node
print(g.out_degrees(0))
# In degrees of the center node - note that the graph is directed so the in degree should be 0.
print(g.in_degrees(0))
6
5
5
0

图变换

DGL 提供了许多 API 来将一个图转换为另一个图,例如提取子图

# Induce a subgraph from node 0, node 1 and node 3 from the original graph.
sg1 = g.subgraph([0, 1, 3])
# Induce a subgraph from edge 0, edge 1 and edge 3 from the original graph.
sg2 = g.edge_subgraph([0, 1, 3])

您可以通过查看新图中的节点特征 dgl.NID 或边特征 dgl.EID 来获取从子图到原始图的节点/边映射。

# The original IDs of each node in sg1
print(sg1.ndata[dgl.NID])
# The original IDs of each edge in sg1
print(sg1.edata[dgl.EID])
# The original IDs of each node in sg2
print(sg2.ndata[dgl.NID])
# The original IDs of each edge in sg2
print(sg2.edata[dgl.EID])
tensor([0, 1, 3])
tensor([0, 2])
tensor([0, 1, 2, 4])
tensor([0, 1, 3])

subgraphedge_subgraph 也会将原始特征复制到子图

# The original node feature of each node in sg1
print(sg1.ndata["x"])
# The original edge feature of each node in sg1
print(sg1.edata["a"])
# The original node feature of each node in sg2
print(sg2.ndata["x"])
# The original edge feature of each node in sg2
print(sg2.edata["a"])
tensor([[-0.4628,  0.5122,  0.8058],
        [ 1.3146,  0.5920,  0.0904],
        [-1.0726,  1.5550,  1.3860]])
tensor([[ 0.0240, -1.5106,  1.5920,  0.6187],
        [ 1.7433, -2.2757,  0.4255,  0.8959]])
tensor([[-0.4628,  0.5122,  0.8058],
        [ 1.3146,  0.5920,  0.0904],
        [ 1.4425,  1.2905,  0.8477],
        [-1.5461,  0.0847,  1.1992]])
tensor([[ 0.0240, -1.5106,  1.5920,  0.6187],
        [ 0.4512,  1.4332,  0.2833, -0.1837],
        [ 1.3513,  1.0305,  0.2473, -0.6149]])

另一种常见的变换是使用 dgl.add_reverse_edges 为原始图中的每条边添加一条反向边。

注意

如果您有一个无向图,最好先通过添加反向边将其转换为双向图。

newg = dgl.add_reverse_edges(g)
print(newg.edges())
(tensor([0, 0, 0, 0, 0, 1, 2, 3, 4, 5]), tensor([1, 2, 3, 4, 5, 0, 0, 0, 0, 0]))

加载和保存图

您可以使用 dgl.save_graphs 保存一个图或图列表,并使用 dgl.load_graphs 将它们重新加载回来。

# Save graphs
dgl.save_graphs("graph.dgl", g)
dgl.save_graphs("graphs.dgl", [g, sg1, sg2])

# Load graphs
(g,), _ = dgl.load_graphs("graph.dgl")
print(g)
(g, sg1, sg2), _ = dgl.load_graphs("graphs.dgl")
print(g)
print(sg1)
print(sg2)
Graph(num_nodes=6, num_edges=5,
      ndata_schemes={'y': Scheme(shape=(5, 4), dtype=torch.float32), 'x': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'a': Scheme(shape=(4,), dtype=torch.float32)})
Graph(num_nodes=6, num_edges=5,
      ndata_schemes={'y': Scheme(shape=(5, 4), dtype=torch.float32), 'x': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'a': Scheme(shape=(4,), dtype=torch.float32)})
Graph(num_nodes=3, num_edges=2,
      ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(5, 4), dtype=torch.float32), 'x': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'a': Scheme(shape=(4,), dtype=torch.float32)})
Graph(num_nodes=4, num_edges=3,
      ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(5, 4), dtype=torch.float32), 'x': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), 'a': Scheme(shape=(4,), dtype=torch.float32)})

接下来?

# Thumbnail credits: Wikipedia
# sphinx_gallery_thumbnail_path = '_static/blitz_2_dglgraph.png'

脚本总运行时间: (0 分钟 0.014 秒)

由 Sphinx-Gallery 生成的图库