dgl.from_networkx

dgl.from_networkx(nx_graph, node_attrs=None, edge_attrs=None, edge_id_attr_name=None, idtype=None, device=None)[source]

从 NetworkX 图创建 DGL 图并返回。

注意

从 NetworkX 图创建 DGLGraph 速度不快,尤其对于大型图。建议先将 NetworkX 图转换为一个节点张量元组,然后使用 dgl.graph() 构建 DGLGraph。

参数:
  • nx_graph (networkx.Graph) – 持有图结构和节点/边属性的 NetworkX 图。如果节点不是从零开始的连续整数,DGL 将重新标记节点。如果输入图是无向图,DGL 会通过 networkx.Graph.to_directed() 将其转换为有向图。

  • node_attrs (list[str], 可选) – 要从 NetworkX 图中检索的节点属性名称。如果给定,DGL 会使用其原始名称将检索到的节点属性存储在返回图的 ndata 中。属性数据必须可转换为 Tensor 类型(例如,标量、numpy.ndarray、list 等)。

  • edge_attrs (list[str], 可选) – 要从 NetworkX 图中检索的边属性名称。如果给定,DGL 会使用其原始名称将检索到的边属性存储在返回图的 edata 中。属性数据必须可转换为 Tensor 类型(例如,标量、numpy.ndarray、list 等)。如果 nx_graph 是无向图,则此参数必须为 None。

  • edge_id_attr_name (str, 可选) – 存储边 ID 的边属性名称。如果给定,DGL 会在创建图时相应地分配边 ID,因此此属性必须是有效的 ID,即从零开始的连续整数。默认情况下,返回图的边 ID 可以是任意值。如果 nx_graph 是无向图,则此参数必须为 None。

  • idtype (int32int64, 可选) – 用于存储与结构相关的图信息(如节点和边 ID)的数据类型。它应该是一个框架特定的数据类型对象(例如,torch.int32)。默认情况下,DGL 使用 int64。

  • device (设备上下文, 可选) – 结果图的设备。它应该是一个框架特定的设备对象(例如,torch.device)。默认情况下,DGL 将图存储在 CPU 上。

返回:

创建的图。

返回类型:

DGLGraph

说明

DGL 内部维护着多种稀疏格式的图结构副本,并根据调用的计算选择最有效的格式。如果在大图情况下内存使用成为问题,可以使用 dgl.DGLGraph.formats() 来限制允许的格式。

示例

以下示例使用 PyTorch 后端。

>>> import dgl
>>> import networkx as nx
>>> import numpy as np
>>> import torch

创建一个包含 2 条边的 NetworkX 图。

>>> nx_g = nx.DiGraph()
>>> # Add 3 nodes and two features for them
>>> nx_g.add_nodes_from([0, 1, 2], feat1=np.zeros((3, 1)), feat2=np.ones((3, 1)))
>>> # Add 2 edges (1, 2) and (2, 1) with two features, one being edge IDs
>>> nx_g.add_edge(1, 2, weight=np.ones((1, 1)), eid=np.array([1]))
>>> nx_g.add_edge(2, 1, weight=np.ones((1, 1)), eid=np.array([0]))

将其转换为仅包含结构的 DGLGraph。

>>> g = dgl.from_networkx(nx_g)

检索图的节点/边特征。

>>> g = dgl.from_networkx(nx_g, node_attrs=['feat1', 'feat2'], edge_attrs=['weight'])

使用预先指定的边排序。

>>> g.edges()
(tensor([1, 2]), tensor([2, 1]))
>>> g = dgl.from_networkx(nx_g, edge_id_attr_name='eid')
(tensor([2, 1]), tensor([1, 2]))

在第一块 GPU 上创建一个数据类型为 int32 的图。

>>> g = dgl.from_networkx(nx_g, idtype=torch.int32, device='cuda:0')

另请参阅

graphfrom_scipy