SortPooling

class dgl.nn.pytorch.glob.SortPooling(k)[源码]

基类: Module

Sort Pooling 源自 An End-to-End Deep Learning Architecture for Graph Classification

它首先沿特征维度按升序对节点特征进行排序,然后选择前 k 个节点(按每个节点的最大值排序)的排序特征。

参数:

k (int) – 每个图保留的节点数量。

注意事项

输入: 可以是单个图,也可以是批量图。如果使用批量图,请确保所有图中的节点具有相同的特征大小,并将节点的特征连接起来作为输入。

示例

>>> import dgl
>>> import torch as th
>>> from dgl.nn import SortPooling
>>>
>>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges
>>> g1_node_feats = th.rand(3, 5)  # feature size is 5
>>> g1_node_feats
tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],
        [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],
        [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])
>>>
>>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges
>>> g2_node_feats = th.rand(4, 5)  # feature size is 5
>>> g2_node_feats
tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],
        [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],
        [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],
        [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])
>>>
>>> sortpool = SortPooling(k=2)  # create a sort pooling layer

情况 1: 输入单个图

>>> sortpool(g1, g1_node_feats)
tensor([[0.0699, 0.3637, 0.7567, 0.8948, 0.9137, 0.4755, 0.5197, 0.5725, 0.6825,
         0.9030]])

情况 2: 输入批量图

构建批量 DGL 图,并将所有图的节点特征连接到一个张量中。

>>> batch_g = dgl.batch([g1, g2])
>>> batch_f = th.cat([g1_node_feats, g2_node_feats])
>>>
>>> sortpool(batch_g, batch_f)
tensor([[0.0699, 0.3637, 0.7567, 0.8948, 0.9137, 0.4755, 0.5197, 0.5725, 0.6825,
         0.9030],
        [0.2351, 0.5278, 0.6365, 0.8945, 0.9990, 0.2053, 0.2426, 0.4111, 0.5658,
         0.9028]])
forward(graph, feat)[源码]

计算 Sort Pooling。

参数:
  • graph (DGLGraph) – 一个 DGLGraph 或批量 DGLGraphs。

  • feat (torch.Tensor) – 输入节点特征,形状为 \((N, D)\),其中 \(N\) 是图中的节点数量,\(D\) 表示特征的大小。

返回值:

输出特征,形状为 \((B, k * D)\),其中 \(B\) 指输入图的批量大小。

返回类型:

torch.Tensor