dgl.to_simple

dgl.to_simple(g, return_counts='count', writeback_mapping=False, copy_ndata=True, copy_edata=False, aggregator='arbitrary')[源代码]

将图转换为一个没有并行边的简单图并返回。

对于具有多个边类型的异构图,DGL 将具有相同边类型和端点的边视为并行边并将其移除。可以选择通过指定 return_counts 参数来获取并行边的数量。要获取从输入图中的边 ID 到结果图中的边 ID 的映射,请将 writeback_mapping 设置为 true。

参数:
  • g (DGLGraph) – 输入图。必须在 CPU 上。

  • return_counts (str, optional) –

    如果给定,原始图中每条边的计数将作为边特征存储在 return_counts 名称下。同名的旧特征将被替换。

    (默认: “count”)

  • writeback_mapping (bool, optional) –

    如果为 True,则为每种边类型返回一个额外的回写映射。回写映射是一个张量,记录了从输入图中的边 ID 到结果图中的边 ID 的映射。如果图是异构的,DGL 会返回一个包含边类型和这些张量的字典。

    如果为 False,则只返回简单图。

    (默认: False)

  • copy_ndata (bool, optional) –

    如果为 True,则简单图的节点特征将从原始图复制。

    如果为 False,则简单图将不包含任何节点特征。

    (默认: True)

  • copy_edata (bool, optional) –

    如果为 True,则简单图的边特征将从原始图复制。如果两个节点 (u, v) 之间存在重复边,则该边的特征是重复边特征的聚合。

    如果为 False,则简单图将不包含任何边特征。

    (默认: False)

  • aggregator (str, optional) –

    指示如何合并重复边的边特征。如果为 arbitrary,则选择其中一条重复边的特征。如果为 sum,则计算重复边特征的总和。如果为 mean,则计算重复边特征的平均值。

    (默认: arbitrary)

返回值:

  • DGLGraph – 结果图。

  • tensor 或 dict of tensor – 回写映射。仅当 writeback_mapping 为 True 时返回。

注意

如果 copy_ndata 为 True,结果图将与输入图共享节点特征张量。因此,用户应尽量避免可能对两个图都可见的就地操作。

此函数会丢弃批次信息。请在转换后的图上使用 dgl.DGLGraph.set_batch_num_nodes()dgl.DGLGraph.set_batch_num_edges() 来维护该信息。

示例

同构图

创建一个图来演示 to_simple API。在原始图中,节点 1 和节点 2 之间有多条边。

>>> import dgl
>>> import torch as th
>>> g = dgl.graph((th.tensor([0, 1, 2, 1]), th.tensor([1, 2, 0, 2])))
>>> g.ndata['h'] = th.tensor([[0.], [1.], [2.]])
>>> g.edata['h'] = th.tensor([[3.], [4.], [5.], [6.]])

将图转换为简单图。返回的计数存储在边特征 'cnt' 中,回写映射在一个张量中返回。

>>> sg, wm = dgl.to_simple(g, return_counts='cnt', writeback_mapping=True)
>>> sg.ndata['h']
tensor([[0.],
        [1.],
        [2.]])
>>> u, v, eid = sg.edges(form='all')
>>> u
tensor([0, 1, 2])
>>> v
tensor([1, 2, 0])
>>> eid
tensor([0, 1, 2])
>>> sg.edata['cnt']
tensor([1, 2, 1])
>>> wm
tensor([0, 1, 2, 1])
>>> 'h' in g.edata
False

异构图

>>> g = dgl.heterograph({
...     ('user', 'wins', 'user'): (th.tensor([0, 2, 0, 2, 2]), th.tensor([1, 1, 2, 1, 0])),
...     ('user', 'plays', 'game'): (th.tensor([1, 2, 1]), th.tensor([2, 1, 1]))
... })
>>> g.nodes['game'].data['hv'] = th.ones(3, 1)
>>> g.edges['plays'].data['he'] = th.zeros(3, 1)

返回的计数作为每种边类型的默认边特征 'count' 存储。

>>> sg, wm = dgl.to_simple(g, copy_ndata=False, writeback_mapping=True)
>>> sg
Graph(num_nodes={'game': 3, 'user': 3},
      num_edges={('user', 'wins', 'user'): 4, ('game', 'plays', 'user'): 3},
      metagraph=[('user', 'user'), ('game', 'user')])
>>> sg.edges(etype='wins')
(tensor([0, 2, 0, 2]), tensor([1, 1, 2, 0]))
>>> wm[('user', 'wins', 'user')]
tensor([0, 1, 2, 1, 3])
>>> sg.edges(etype='plays')
(tensor([2, 1, 1]), tensor([1, 2, 1]))
>>> wm[('user', 'plays', 'game')]
tensor([0, 1, 2])
>>> 'hv' in sg.nodes['game'].data
False
>>> 'he' in sg.edges['plays'].data
False
>>> sg.edata['count']
{('user', 'wins', 'user'): tensor([1, 2, 1, 1])
 ('user', 'plays', 'game'): tensor([1, 1, 1])}