dgl.graph
- dgl.graph(data, *, num_nodes=None, idtype=None, device=None, row_sorted=False, col_sorted=False)[源码]
创建一个图并返回。
- 参数:
data (图数据) –
用于构建图的数据,形式为 \(U, V)\)。\((U[i], V[i])\) 在图中形成 ID 为 \(i\) 的边。允许的数据格式有
(Tensor, Tensor)
: 每个张量必须是包含节点 ID 的一维张量。DGL 将这种格式称为“节点张量元组”。张量应具有相同的数据类型(int32/int64)和设备上下文(详见下方idtype
和device
的描述)。('coo', (Tensor, Tensor))
: 与(Tensor, Tensor)
相同。('csr', (Tensor, Tensor, Tensor))
: 这三个张量构成了图邻接矩阵的 CSR 表示。第一个是行索引指针。第二个是列索引。第三个是边 ID,可以为空,表示从 0 开始的连续整数 ID。('csc', (Tensor, Tensor, Tensor))
: 这三个张量构成了图邻接矩阵的 CSC 表示。第一个是列索引指针。第二个是行索引。第三个是边 ID,可以为空,表示从 0 开始的连续整数 ID。
张量可以替换为任何整数可迭代对象(例如 list、tuple、numpy.ndarray)。
num_nodes (int, 可选) – 图中节点的数量。如果未指定,则默认为
data
参数中最大的节点 ID 加 1。如果指定的值不大于data
参数中最大的节点 ID,DGL 将报错。idtype (int32 或 int64, 可选) – 用于存储图结构相关信息(如节点和边 ID)的数据类型。应为框架特定的数据类型对象(例如
torch.int32
)。如果为None
(默认),DGL 将从data
参数中推断 ID 类型。更多详情请参阅“说明”。device (设备上下文, 可选) – 返回图的设备,应为框架特定的设备对象(例如
torch.device
)。如果为None
(默认),DGL 将使用data
参数中张量的设备。如果data
不是节点张量元组,则返回的图位于 CPU 上。如果指定的device
与提供的张量的设备不同,DGL 会先将给定的张量转换为指定的设备。row_sorted (bool, 可选) – COO 的行是否按升序排列。
col_sorted (bool, 可选) – COO 的列在每行内是否按升序排列。这仅在
row_sorted
为 True 时有效。
- 返回:
创建的图。
- 返回类型:
说明
如果未提供
idtype
参数,则对于节点张量元组格式,DGL 使用给定 ID 张量的数据类型。
对于序列元组格式,DGL 使用 int64。
图创建后,您可以使用
dgl.DGLGraph.long()
或dgl.DGLGraph.int()
更改数据类型。如果指定的
idtype
参数与提供的张量的数据类型不同,DGL 会先将给定的张量转换为指定的数据类型。最有效的构建方法是提供节点张量元组,并且不指定
idtype
和device
。因为在这种情况下,返回的图与输入的节点张量共享存储。DGL 内部以不同的稀疏格式维护图结构的多个副本,并根据调用的计算选择最有效的格式。如果处理大型图时内存使用成为问题,请使用
dgl.DGLGraph.formats()
限制允许的格式。
示例
以下示例使用 PyTorch 后端。
>>> import dgl >>> import torch
创建一个小的三边图。
>>> # Source nodes for edges (2, 1), (3, 2), (4, 3) >>> src_ids = torch.tensor([2, 3, 4]) >>> # Destination nodes for edges (2, 1), (3, 2), (4, 3) >>> dst_ids = torch.tensor([1, 2, 3]) >>> g = dgl.graph((src_ids, dst_ids))
显式指定图中的节点数。
>>> g = dgl.graph((src_ids, dst_ids), num_nodes=100)
在第一个 GPU 上创建一个数据类型为 int32 的图。
>>> g = dgl.graph((src_ids, dst_ids), idtype=torch.int32, device='cuda:0')
使用 CSR 表示创建图
>>> g = dgl.graph(('csr', ([0, 0, 0, 1, 2, 3], [1, 2, 3], [])))
使用 CSR 表示和边 ID 创建相同的图。
>>> g = dgl.graph(('csr', ([0, 0, 0, 1, 2, 3], [1, 2, 3], [0, 1, 2])))
另请参阅