dgl.from_networkx
- dgl.from_networkx(nx_graph, node_attrs=None, edge_attrs=None, edge_id_attr_name=None, idtype=None, device=None)[source]
从 NetworkX 图创建 DGL 图并返回。
注意
从 NetworkX 图创建 DGLGraph 速度不快,尤其对于大型图。建议先将 NetworkX 图转换为一个节点张量元组,然后使用
dgl.graph()
构建 DGLGraph。- 参数:
nx_graph (networkx.Graph) – 持有图结构和节点/边属性的 NetworkX 图。如果节点不是从零开始的连续整数,DGL 将重新标记节点。如果输入图是无向图,DGL 会通过
networkx.Graph.to_directed()
将其转换为有向图。node_attrs (list[str], 可选) – 要从 NetworkX 图中检索的节点属性名称。如果给定,DGL 会使用其原始名称将检索到的节点属性存储在返回图的
ndata
中。属性数据必须可转换为 Tensor 类型(例如,标量、numpy.ndarray、list 等)。edge_attrs (list[str], 可选) – 要从 NetworkX 图中检索的边属性名称。如果给定,DGL 会使用其原始名称将检索到的边属性存储在返回图的
edata
中。属性数据必须可转换为 Tensor 类型(例如,标量、numpy.ndarray
、list 等)。如果nx_graph
是无向图,则此参数必须为 None。edge_id_attr_name (str, 可选) – 存储边 ID 的边属性名称。如果给定,DGL 会在创建图时相应地分配边 ID,因此此属性必须是有效的 ID,即从零开始的连续整数。默认情况下,返回图的边 ID 可以是任意值。如果
nx_graph
是无向图,则此参数必须为 None。idtype (int32 或 int64, 可选) – 用于存储与结构相关的图信息(如节点和边 ID)的数据类型。它应该是一个框架特定的数据类型对象(例如,
torch.int32
)。默认情况下,DGL 使用 int64。device (设备上下文, 可选) – 结果图的设备。它应该是一个框架特定的设备对象(例如,
torch.device
)。默认情况下,DGL 将图存储在 CPU 上。
- 返回:
创建的图。
- 返回类型:
说明
DGL 内部维护着多种稀疏格式的图结构副本,并根据调用的计算选择最有效的格式。如果在大图情况下内存使用成为问题,可以使用
dgl.DGLGraph.formats()
来限制允许的格式。示例
以下示例使用 PyTorch 后端。
>>> import dgl >>> import networkx as nx >>> import numpy as np >>> import torch
创建一个包含 2 条边的 NetworkX 图。
>>> nx_g = nx.DiGraph() >>> # Add 3 nodes and two features for them >>> nx_g.add_nodes_from([0, 1, 2], feat1=np.zeros((3, 1)), feat2=np.ones((3, 1))) >>> # Add 2 edges (1, 2) and (2, 1) with two features, one being edge IDs >>> nx_g.add_edge(1, 2, weight=np.ones((1, 1)), eid=np.array([1])) >>> nx_g.add_edge(2, 1, weight=np.ones((1, 1)), eid=np.array([0]))
将其转换为仅包含结构的 DGLGraph。
>>> g = dgl.from_networkx(nx_g)
检索图的节点/边特征。
>>> g = dgl.from_networkx(nx_g, node_attrs=['feat1', 'feat2'], edge_attrs=['weight'])
使用预先指定的边排序。
>>> g.edges() (tensor([1, 2]), tensor([2, 1])) >>> g = dgl.from_networkx(nx_g, edge_id_attr_name='eid') (tensor([2, 1]), tensor([1, 2]))
在第一块 GPU 上创建一个数据类型为 int32 的图。
>>> g = dgl.from_networkx(nx_g, idtype=torch.int32, device='cuda:0')
另请参阅