dgl.node_subgraph
- dgl.node_subgraph(graph, nodes, *, relabel_nodes=True, store_ids=True, output_device=None)[源码]
返回在给定节点上诱导的子图。
节点诱导子图是包含其两端节点都在指定节点集中的边的图。除了提取子图,DGL 还将提取的节点和边的特征复制到结果图中。复制是惰性的,仅在需要时才会发生数据移动。
如果图是异构图,DGL 会为每种关系提取一个子图,并将它们组合成结果图。因此,结果图具有与输入图相同的关系集合。
- 参数:
graph (DGLGraph) – 用于提取子图的图。
nodes (nodes 或 dict[str, nodes]) –
用于构成子图的节点,不能有任何重复值。否则结果将是未定义的。允许的节点格式有
Int Tensor: 每个元素都是一个节点 ID。该 Tensor 必须与图具有相同的设备类型和 ID 数据类型。
iterable[int]: 每个元素都是一个节点 ID。
Bool Tensor: 每个第 \(i\) 个元素是一个布尔标志,指示节点 \(i\) 是否在子图中。
如果图是同构图,可以直接传递以上格式。否则,该参数必须是一个字典,键是节点类型,值是以上述格式表示的节点 ID。
relabel_nodes (bool, 可选) – 如果为 True,提取的子图将仅包含指定节点集中的节点,并将按顺序重新标记这些节点。
store_ids (bool, 可选) – 如果为 True,它将在结果图的
edata
中以名称dgl.EID
存储提取边的原始 ID;如果relabel_nodes
为True
,它还将在结果图的ndata
中以名称dgl.NID
存储指定节点的原始 ID。output_device (框架特定的设备上下文对象, 可选) – 输出设备。默认为与输入图相同。
- 返回值:
G – 子图。
- 返回类型:
注意
此函数会丢弃批处理信息。请在转换后的图上使用
dgl.DGLGraph.set_batch_num_nodes()
和dgl.DGLGraph.set_batch_num_edges()
来维护该信息。示例
以下示例使用 PyTorch 后端。
>>> import dgl >>> import torch
从同构图中提取子图。
>>> g = dgl.graph(([0, 1, 2, 3, 4], [1, 2, 3, 4, 0])) # 5-node cycle >>> sg = dgl.node_subgraph(g, [0, 1, 4]) >>> sg Graph(num_nodes=3, num_edges=2, ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)} edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}) >>> sg.edges() (tensor([0, 2]), tensor([1, 0])) >>> sg.ndata[dgl.NID] # original node IDs tensor([0, 1, 4]) >>> sg.edata[dgl.EID] # original edge IDs tensor([0, 4])
使用布尔掩码指定节点。
>>> nodes = torch.tensor([True, True, False, False, True]) # choose nodes [0, 1, 4] >>> dgl.node_subgraph(g, nodes) Graph(num_nodes=3, num_edges=2, ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)} edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})
结果子图也会从父图复制特征。
>>> g.ndata['x'] = torch.arange(10).view(5, 2) >>> sg = dgl.node_subgraph(g, [0, 1, 4]) >>> sg Graph(num_nodes=3, num_edges=2, ndata_schemes={'x': Scheme(shape=(2,), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)} edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)}) >>> sg.ndata['x'] tensor([[0, 1], [2, 3], [8, 9]])
从异构图中提取子图。
>>> g = dgl.heterograph({ >>> ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]), >>> ('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2]) >>> }) >>> sub_g = dgl.node_subgraph(g, {'user': [1, 2]}) >>> sub_g Graph(num_nodes={'game': 0, 'user': 2}, num_edges={('user', 'follows', 'user'): 2, ('user', 'plays', 'game'): 0}, metagraph=[('user', 'user', 'follows'), ('user', 'game', 'plays')])
另请参阅