dgl.edge_type_subgraph
- dgl.edge_type_subgraph(graph, etypes, output_device=None)[source]
返回在给定边类型上导出的子图。
边类型导出的子图包含图的给定边类型子集中的所有边。如果某种类型的节点与这些边关联,它也包含该类型的所有节点。除了提取子图,DGL 还会将提取的节点和边的特征复制到结果图中。复制是惰性的,仅在需要时才会引起数据移动。
- 参数:
- 返回值:
G – 子图。
- 返回值类型:
注意事项
此函数会丢弃批处理信息。请在转换后的图上使用
dgl.DGLGraph.set_batch_num_nodes()
和dgl.DGLGraph.set_batch_num_edges()
来维护这些信息。示例
以下示例使用 PyTorch 后端。
>>> import dgl >>> import torch
实例化一个异构图。
>>> g = dgl.heterograph({ >>> ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 2, 1]), >>> ('user', 'follows', 'user'): ([0, 1, 1], [1, 2, 2]) >>> }) >>> # Set edge features >>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [2.]])
获取子图。
>>> sub_g = g.edge_type_subgraph(['follows']) >>> sub_g Graph(num_nodes=3, num_edges=3, ndata_schemes={} edata_schemes={'h': Scheme(shape=(1,), dtype=torch.float32)})
获取共享边特征。
>>> sub_g.edges['follows'].data['h'] tensor([[0.], [1.], [2.]])
另请参阅