dgl.to_simple
- dgl.to_simple(g, return_counts='count', writeback_mapping=False, copy_ndata=True, copy_edata=False, aggregator='arbitrary')[源代码]
将图转换为一个没有并行边的简单图并返回。
对于具有多个边类型的异构图,DGL 将具有相同边类型和端点的边视为并行边并将其移除。可以选择通过指定
return_counts
参数来获取并行边的数量。要获取从输入图中的边 ID 到结果图中的边 ID 的映射,请将writeback_mapping
设置为 true。- 参数:
g (DGLGraph) – 输入图。必须在 CPU 上。
return_counts (str, optional) –
如果给定,原始图中每条边的计数将作为边特征存储在
return_counts
名称下。同名的旧特征将被替换。(默认: “count”)
writeback_mapping (bool, optional) –
如果为 True,则为每种边类型返回一个额外的回写映射。回写映射是一个张量,记录了从输入图中的边 ID 到结果图中的边 ID 的映射。如果图是异构的,DGL 会返回一个包含边类型和这些张量的字典。
如果为 False,则只返回简单图。
(默认: False)
copy_ndata (bool, optional) –
如果为 True,则简单图的节点特征将从原始图复制。
如果为 False,则简单图将不包含任何节点特征。
(默认: True)
copy_edata (bool, optional) –
如果为 True,则简单图的边特征将从原始图复制。如果两个节点 (u, v) 之间存在重复边,则该边的特征是重复边特征的聚合。
如果为 False,则简单图将不包含任何边特征。
(默认: False)
aggregator (str, optional) –
指示如何合并重复边的边特征。如果为
arbitrary
,则选择其中一条重复边的特征。如果为sum
,则计算重复边特征的总和。如果为mean
,则计算重复边特征的平均值。(默认:
arbitrary
)
- 返回值:
DGLGraph – 结果图。
tensor 或 dict of tensor – 回写映射。仅当
writeback_mapping
为 True 时返回。
注意
如果
copy_ndata
为 True,结果图将与输入图共享节点特征张量。因此,用户应尽量避免可能对两个图都可见的就地操作。此函数会丢弃批次信息。请在转换后的图上使用
dgl.DGLGraph.set_batch_num_nodes()
和dgl.DGLGraph.set_batch_num_edges()
来维护该信息。示例
同构图
创建一个图来演示 to_simple API。在原始图中,节点 1 和节点 2 之间有多条边。
>>> import dgl >>> import torch as th >>> g = dgl.graph((th.tensor([0, 1, 2, 1]), th.tensor([1, 2, 0, 2]))) >>> g.ndata['h'] = th.tensor([[0.], [1.], [2.]]) >>> g.edata['h'] = th.tensor([[3.], [4.], [5.], [6.]])
将图转换为简单图。返回的计数存储在边特征 'cnt' 中,回写映射在一个张量中返回。
>>> sg, wm = dgl.to_simple(g, return_counts='cnt', writeback_mapping=True) >>> sg.ndata['h'] tensor([[0.], [1.], [2.]]) >>> u, v, eid = sg.edges(form='all') >>> u tensor([0, 1, 2]) >>> v tensor([1, 2, 0]) >>> eid tensor([0, 1, 2]) >>> sg.edata['cnt'] tensor([1, 2, 1]) >>> wm tensor([0, 1, 2, 1]) >>> 'h' in g.edata False
异构图
>>> g = dgl.heterograph({ ... ('user', 'wins', 'user'): (th.tensor([0, 2, 0, 2, 2]), th.tensor([1, 1, 2, 1, 0])), ... ('user', 'plays', 'game'): (th.tensor([1, 2, 1]), th.tensor([2, 1, 1])) ... }) >>> g.nodes['game'].data['hv'] = th.ones(3, 1) >>> g.edges['plays'].data['he'] = th.zeros(3, 1)
返回的计数作为每种边类型的默认边特征 'count' 存储。
>>> sg, wm = dgl.to_simple(g, copy_ndata=False, writeback_mapping=True) >>> sg Graph(num_nodes={'game': 3, 'user': 3}, num_edges={('user', 'wins', 'user'): 4, ('game', 'plays', 'user'): 3}, metagraph=[('user', 'user'), ('game', 'user')]) >>> sg.edges(etype='wins') (tensor([0, 2, 0, 2]), tensor([1, 1, 2, 0])) >>> wm[('user', 'wins', 'user')] tensor([0, 1, 2, 1, 3]) >>> sg.edges(etype='plays') (tensor([2, 1, 1]), tensor([1, 2, 1])) >>> wm[('user', 'plays', 'game')] tensor([0, 1, 2]) >>> 'hv' in sg.nodes['game'].data False >>> 'he' in sg.edges['plays'].data False >>> sg.edata['count'] {('user', 'wins', 'user'): tensor([1, 2, 1, 1]) ('user', 'plays', 'game'): tensor([1, 1, 1])}