dgl.in_subgraph

dgl.in_subgraph(graph, nodes, *, relabel_nodes=False, store_ids=True, output_device=None)[源代码]

返回由给定节点的所有边类型的入边诱导的子图。

入边子图等价于使用给定节点的入边创建新图。除了提取子图外,DGL 还会将提取的节点和边的特征复制到结果图中。复制是惰性的,仅在需要时才会发生数据移动。

如果图是异构的,DGL 将按关系提取子图,并将其组合为结果图。因此,结果图与输入图具有相同的关系集。

参数:
  • graph (DGLGraph) – 输入图。

  • nodes (节点dict[str, 节点]) –

    用于形成子图的节点,不能包含任何重复值。否则结果将是未定义的。允许的节点格式为

    • Int Tensor:每个元素是一个节点 ID。Tensor 的设备类型和 ID 数据类型必须与图的相同。

    • iterable[int]:每个元素是一个节点 ID。

    如果图是同构的,可以直接传递上述格式。否则,参数必须是一个字典,其中键是节点类型,值是上述格式的节点 ID。

  • relabel_nodes (bool, 可选) – 如果为 True,它将删除孤立节点并重新标记提取子图中的其余节点。

  • store_ids (bool, 可选) – 如果为 True,它将在结果图的 edata 中存储提取边的原始 ID,名称为 dgl.EID;如果 relabel_nodesTrue,它还将在结果图的 ndata 中存储提取节点的原始 ID,名称为 dgl.NID

  • output_device (框架特定的设备上下文对象, 可选) – 输出设备。默认为与输入图相同。

返回:

子图。

返回类型:

DGLGraph

附注

此函数会丢弃批处理信息。请在转换后的图上使用 dgl.DGLGraph.set_batch_num_nodes()dgl.DGLGraph.set_batch_num_edges() 来维护该信息。

示例

以下示例使用 PyTorch 后端。

>>> import dgl
>>> import torch

从同构图中提取子图。

>>> g = dgl.graph(([0, 1, 2, 3, 4], [1, 2, 3, 4, 0]))  # 5-node cycle
>>> g.edata['w'] = torch.arange(10).view(5, 2)
>>> sg = dgl.in_subgraph(g, [2, 0])
>>> sg
Graph(num_nodes=5, num_edges=2,
      ndata_schemes={}
      edata_schemes={'w': Scheme(shape=(2,), dtype=torch.int64),
                     '_ID': Scheme(shape=(), dtype=torch.int64)})
>>> sg.edges()
(tensor([1, 4]), tensor([2, 0]))
>>> sg.edata[dgl.EID]  # original edge IDs
tensor([1, 4])
>>> sg.edata['w']  # also extract the features
tensor([[2, 3],
        [8, 9]])

提取带节点标签的子图。

>>> sg = dgl.in_subgraph(g, [2, 0], relabel_nodes=True)
>>> sg
Graph(num_nodes=4, num_edges=2,
      ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64}
      edata_schemes={'w': Scheme(shape=(2,), dtype=torch.int64),
                     '_ID': Scheme(shape=(), dtype=torch.int64)})
>>> sg.edges()
(tensor([1, 3]), tensor([2, 0]))
>>> sg.edata[dgl.EID]  # original edge IDs
tensor([1, 4])
>>> sg.ndata[dgl.NID]  # original node IDs
tensor([0, 1, 2, 4])

从异构图中提取子图。

>>> g = dgl.heterograph({
...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),
...     ('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2])})
>>> sub_g = g.in_subgraph({'user': [2], 'game': [2]})
>>> sub_g
Graph(num_nodes={'game': 3, 'user': 3},
      num_edges={('user', 'plays', 'game'): 1, ('user', 'follows', 'user'): 2},
      metagraph=[('user', 'game', 'plays'), ('user', 'user', 'follows')])

另请参阅

out_subgraph