dgl.topk_nodes
- dgl.topk_nodes(graph, feat, k, *, descending=True, sortby=None, ntype=None)[源代码]
通过在图
graph
的节点特征feat
上执行图范围内的 top-k 操作,并按索引sortby
处的特征进行排序,返回图级别的表示。如果将
descending
设置为 False,则返回 k 个最小元素。如果将
sortby
设置为 None,函数将独立地在所有维度上执行 top-k 操作,等同于调用torch.topk(graph.ndata[feat], dim=0)
。- 参数:
- 返回值:
sorted_feat (Tensor) – 形状为 \((B, K, D)\) 的张量,其中 \(B\) 是输入图的批大小。
sorted_idx (Tensor) – 形状为 \((B, K)\)(如果 sortby 设置为 None,则形状为 \((B, K, D)\))的张量,其中 \(B\) 是输入图的批大小,\(D\) 是特征维度。
注意事项
如果一个样本有 \(n\) 个节点且 \(n<k\),则
sorted_feat
张量将用零填充从第 \(n+1\) 行到第 \(k\) 行;示例
>>> import dgl >>> import torch as th
创建两个
DGLGraph
对象并初始化它们的节点特征。>>> g1 = dgl.graph(([0, 1], [2, 3])) # Graph 1 >>> g1.ndata['h'] = th.rand(4, 5) >>> g1.ndata['h'] tensor([[0.0297, 0.8307, 0.9140, 0.6702, 0.3346], [0.5901, 0.3030, 0.9280, 0.6893, 0.7997], [0.0880, 0.6515, 0.4451, 0.7507, 0.5297], [0.5171, 0.6379, 0.2695, 0.8954, 0.5197]])
>>> g2 = dgl.graph(([0, 1, 2], [2, 3, 4])) # Graph 2 >>> g2.ndata['h'] = th.rand(5, 5) >>> g2.ndata['h'] tensor([[0.3168, 0.3174, 0.5303, 0.0804, 0.3808], [0.1323, 0.2766, 0.4318, 0.6114, 0.1458], [0.1752, 0.9105, 0.5692, 0.8489, 0.0539], [0.1931, 0.4954, 0.3455, 0.3934, 0.0857], [0.5065, 0.5182, 0.5418, 0.1520, 0.3872]])
在批处理图中对节点属性
h
进行 Top-k 操作。>>> bg = dgl.batch([g1, g2], ndata=['h']) >>> dgl.topk_nodes(bg, 'h', 3) (tensor([[[0.5901, 0.8307, 0.9280, 0.8954, 0.7997], [0.5171, 0.6515, 0.9140, 0.7507, 0.5297], [0.0880, 0.6379, 0.4451, 0.6893, 0.5197]], [[0.5065, 0.9105, 0.5692, 0.8489, 0.3872], [0.3168, 0.5182, 0.5418, 0.6114, 0.3808], [0.1931, 0.4954, 0.5303, 0.3934, 0.1458]]]), tensor([[[1, 0, 1, 3, 1], [3, 2, 0, 2, 2], [2, 3, 2, 1, 3]], [[4, 2, 2, 2, 4], [0, 4, 4, 1, 0], [3, 3, 0, 3, 1]]]))
在批处理图中沿着最后一个维度对节点属性
h
进行 Top-k 操作。(用于 SortPooling)>>> dgl.topk_nodes(bg, 'h', 3, sortby=-1) (tensor([[[0.5901, 0.3030, 0.9280, 0.6893, 0.7997], [0.0880, 0.6515, 0.4451, 0.7507, 0.5297], [0.5171, 0.6379, 0.2695, 0.8954, 0.5197]], [[0.5065, 0.5182, 0.5418, 0.1520, 0.3872], [0.3168, 0.3174, 0.5303, 0.0804, 0.3808], [0.1323, 0.2766, 0.4318, 0.6114, 0.1458]]]), tensor([[1, 2, 3], [4, 0, 1]]))
在单个图中对节点属性
h
进行 Top-k 操作。>>> dgl.topk_nodes(g1, 'h', 3) (tensor([[[0.5901, 0.8307, 0.9280, 0.8954, 0.7997], [0.5171, 0.6515, 0.9140, 0.7507, 0.5297], [0.0880, 0.6379, 0.4451, 0.6893, 0.5197]]]), tensor([[[1, 0, 1, 3, 1], [3, 2, 0, 2, 2], [2, 3, 2, 1, 3]]]))