dgl.graph

dgl.graph(data, *, num_nodes=None, idtype=None, device=None, row_sorted=False, col_sorted=False)[源码]

创建一个图并返回。

参数:
  • data (图数据) –

    用于构建图的数据,形式为 \(U, V)\)。\((U[i], V[i])\) 在图中形成 ID 为 \(i\) 的边。允许的数据格式有

    • (Tensor, Tensor): 每个张量必须是包含节点 ID 的一维张量。DGL 将这种格式称为“节点张量元组”。张量应具有相同的数据类型(int32/int64)和设备上下文(详见下方 idtypedevice 的描述)。

    • ('coo', (Tensor, Tensor)): 与 (Tensor, Tensor) 相同。

    • ('csr', (Tensor, Tensor, Tensor)): 这三个张量构成了图邻接矩阵的 CSR 表示。第一个是行索引指针。第二个是列索引。第三个是边 ID,可以为空,表示从 0 开始的连续整数 ID。

    • ('csc', (Tensor, Tensor, Tensor)): 这三个张量构成了图邻接矩阵的 CSC 表示。第一个是列索引指针。第二个是行索引。第三个是边 ID,可以为空,表示从 0 开始的连续整数 ID。

    张量可以替换为任何整数可迭代对象(例如 list、tuple、numpy.ndarray)。

  • num_nodes (int, 可选) – 图中节点的数量。如果未指定,则默认为 data 参数中最大的节点 ID 加 1。如果指定的值不大于 data 参数中最大的节点 ID,DGL 将报错。

  • idtype (int32int64, 可选) – 用于存储图结构相关信息(如节点和边 ID)的数据类型。应为框架特定的数据类型对象(例如 torch.int32)。如果为 None(默认),DGL 将从 data 参数中推断 ID 类型。更多详情请参阅“说明”。

  • device (设备上下文, 可选) – 返回图的设备,应为框架特定的设备对象(例如 torch.device)。如果为 None(默认),DGL 将使用 data 参数中张量的设备。如果 data 不是节点张量元组,则返回的图位于 CPU 上。如果指定的 device 与提供的张量的设备不同,DGL 会先将给定的张量转换为指定的设备。

  • row_sorted (bool, 可选) – COO 的行是否按升序排列。

  • col_sorted (bool, 可选) – COO 的列在每行内是否按升序排列。这仅在 row_sorted 为 True 时有效。

返回:

创建的图。

返回类型:

DGLGraph

说明

  1. 如果未提供 idtype 参数,则

    • 对于节点张量元组格式,DGL 使用给定 ID 张量的数据类型。

    • 对于序列元组格式,DGL 使用 int64。

    图创建后,您可以使用 dgl.DGLGraph.long()dgl.DGLGraph.int() 更改数据类型。

    如果指定的 idtype 参数与提供的张量的数据类型不同,DGL 会先将给定的张量转换为指定的数据类型。

  2. 最有效的构建方法是提供节点张量元组,并且不指定 idtypedevice。因为在这种情况下,返回的图与输入的节点张量共享存储。

  3. DGL 内部以不同的稀疏格式维护图结构的多个副本,并根据调用的计算选择最有效的格式。如果处理大型图时内存使用成为问题,请使用 dgl.DGLGraph.formats() 限制允许的格式。

示例

以下示例使用 PyTorch 后端。

>>> import dgl
>>> import torch

创建一个小的三边图。

>>> # Source nodes for edges (2, 1), (3, 2), (4, 3)
>>> src_ids = torch.tensor([2, 3, 4])
>>> # Destination nodes for edges (2, 1), (3, 2), (4, 3)
>>> dst_ids = torch.tensor([1, 2, 3])
>>> g = dgl.graph((src_ids, dst_ids))

显式指定图中的节点数。

>>> g = dgl.graph((src_ids, dst_ids), num_nodes=100)

在第一个 GPU 上创建一个数据类型为 int32 的图。

>>> g = dgl.graph((src_ids, dst_ids), idtype=torch.int32, device='cuda:0')

使用 CSR 表示创建图

>>> g = dgl.graph(('csr', ([0, 0, 0, 1, 2, 3], [1, 2, 3], [])))

使用 CSR 表示和边 ID 创建相同的图。

>>> g = dgl.graph(('csr', ([0, 0, 0, 1, 2, 3], [1, 2, 3], [0, 1, 2])))

另请参阅

from_scipy, from_networkx