dgl.heterograph

dgl.heterograph(data_dict, num_nodes_dict=None, idtype=None, device=None)[source]

创建异构图并返回。

参数:
  • data_dict (图数据) –

    用于构建异构图的字典数据。键是字符串三元组 (src_type, edge_type, dst_type),指定源节点类型、边类型和目标节点类型。值是 \((U, V)\) 形式的图数据,其中 \((U[i], V[i])\) 构成 ID 为 \(i\) 的边。允许的图数据格式包括

    • (Tensor, Tensor):每个张量必须是包含节点 ID 的一维张量。DGL 将此格式称为“节点张量元组”。张量必须具有相同的数据类型,必须是 int32 或 int64。它们也应该具有相同的设备上下文(参见下方 idtypedevice 的描述)。

    • ('coo', (Tensor, Tensor)):与 (Tensor, Tensor) 相同。

    • ('csr', (Tensor, Tensor, Tensor)):这三个张量构成了图邻接矩阵的 CSR 表示。第一个是行索引指针。第二个是列索引。第三个是边 ID,可以为空(即元素个数为 0),表示从 0 开始的连续整数 ID。

    • ('csc', (Tensor, Tensor, Tensor)):这三个张量构成了图邻接矩阵的 CSC 表示。第一个是列索引指针。第二个是行索引。第三个是边 ID,可以为空,表示从 0 开始的连续整数 ID。

    张量可以用任何可迭代的整数(例如 list, tuple, numpy.ndarray)代替。

  • num_nodes_dict (dict[str, int], optional) – 某些节点类型的节点数,是一个字典,将节点类型 \(T\) 映射到 \(T\) 类型节点的数量。如果未为节点类型 \(T\) 指定此参数,DGL 将查找其源节点或目标节点类型为 \(T\)所有图数据中出现的最大 ID,并将节点数设置为该 ID 加一。如果指定了此参数,且其值不大于某些节点类型的最大 ID,DGL 将引发错误。默认情况下,DGL 会推断所有节点类型的节点数。

  • idtype (int32 or int64, optional) – 用于存储结构相关图信息(如节点和边 ID)的数据类型。它应该是一个框架特定的数据类型对象(例如,torch.int32)。如果为 None(默认),DGL 会从 data_dict 参数推断 ID 类型。

  • device (设备上下文, optional) – 返回图的设备,应该是一个框架特定的设备对象(例如,torch.device)。如果为 None(默认),DGL 使用 data 参数的张量所在的设备。如果 data 不是节点张量元组,则返回的图在 CPU 上。如果指定的 device 与提供的张量所在的设备不同,DGL 会首先将给定的张量转换为指定的设备。

返回:

创建的图。

返回类型:

DGLGraph

注意事项

  1. 如果未给定 idtype 参数,则

    • 对于节点张量元组格式,DGL 使用给定 ID 张量的数据类型。

    • 对于序列元组格式,DGL 使用 int64。

    创建图后,您可以使用 dgl.DGLGraph.long()dgl.DGLGraph.int() 更改数据类型。

    如果指定的 idtype 参数与提供的张量的数据类型不同,DGL 会首先将给定的张量转换为指定的数据类型。

  2. 最有效的构建方法是提供节点张量元组,而不指定 idtypedevice。因为在这种情况下,返回的图与输入的节点张量共享存储。

  3. DGL 内部维护图结构在不同稀疏格式下的多个副本,并根据调用的计算选择最有效的格式。如果在大图情况下内存使用成为问题,请使用 dgl.DGLGraph.formats() 限制允许的格式。

  4. DGL 内部对相同集合的节点类型和规范边类型决定一个确定性顺序,该顺序不一定遵循 data_dict 中的顺序。

示例

以下示例使用 PyTorch 后端。

>>> import dgl
>>> import torch

创建一个包含三种规范边类型的异构图。

>>> data_dict = {
...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),
...     ('user', 'follows', 'topic'): (torch.tensor([1, 1]), torch.tensor([1, 2])),
...     ('user', 'plays', 'game'): (torch.tensor([0, 3]), torch.tensor([3, 4]))
... }
>>> g = dgl.heterograph(data_dict)
>>> g
Graph(num_nodes={'game': 5, 'topic': 3, 'user': 4},
      num_edges={('user', 'follows', 'topic'): 2, ('user', 'follows', 'user'): 2,
                 ('user', 'plays', 'game'): 2},
      metagraph=[('user', 'topic', 'follows'), ('user', 'user', 'follows'),
                 ('user', 'game', 'plays')])

显式指定图中每种节点类型的节点数。

>>> num_nodes_dict = {'user': 4, 'topic': 4, 'game': 6}
>>> g = dgl.heterograph(data_dict, num_nodes_dict=num_nodes_dict)

在第一个 GPU 上创建数据类型为 int32 的图。

>>> g = dgl.heterograph(data_dict, idtype=torch.int32, device='cuda:0')