dgl.edge_type_subgraph

dgl.edge_type_subgraph(graph, etypes, output_device=None)[source]

返回在给定边类型上导出的子图。

边类型导出的子图包含图的给定边类型子集中的所有边。如果某种类型的节点与这些边关联,它也包含该类型的所有节点。除了提取子图,DGL 还会将提取的节点和边的特征复制到结果图中。复制是惰性的,仅在需要时才会引起数据移动。

参数:
  • graph (DGLGraph) – 从中提取子图的图。

  • etypes (list[str] or list[(str, str, str)]) –

    子图中边的类型名称。允许的类型名称格式为

    • (str, str, str) 表示源节点类型、边类型和目标节点类型。

    • 或者一个 str 表示边类型名称,如果该名称在图中可以唯一标识一个三元组格式。

  • output_device (Framework-specific device context object, optional) – 输出设备。默认为与输入图相同的设备。

返回值:

G – 子图。

返回值类型:

DGLGraph

注意事项

此函数会丢弃批处理信息。请在转换后的图上使用 dgl.DGLGraph.set_batch_num_nodes()dgl.DGLGraph.set_batch_num_edges() 来维护这些信息。

示例

以下示例使用 PyTorch 后端。

>>> import dgl
>>> import torch

实例化一个异构图。

>>> g = dgl.heterograph({
>>>     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]),
>>>     ('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2])
>>> })
>>> # Set edge features
>>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [2.]])

获取子图。

>>> sub_g = g.edge_type_subgraph(['follows'])
>>> sub_g
Graph(num_nodes=3, num_edges=3,
      ndata_schemes={}
      edata_schemes={'h': Scheme(shape=(1,), dtype=torch.float32)})

获取共享边特征。

>>> sub_g.edges['follows'].data['h']
tensor([[0.],
        [1.],
        [2.]])

另请参阅

node_type_subgraph