dgl.DGLGraph.filter_nodes

DGLGraph.filter_nodes(predicate, nodes='__ALL__', ntype=None)[源码]

返回满足给定谓词的、具有给定节点类型的节点的 ID。

参数:
  • predicate (callable) – 签名 func(nodes) -> Tensor 的函数。nodesdgl.NodeBatch 对象。其输出张量应为一维布尔张量,每个元素指示批量中相应节点是否满足谓词。

  • nodes (节点 ID(s),可选) –

    用于查询的节点。允许的格式包括

    • Tensor: 一个包含用于查询的节点的一维张量,其数据类型和设备应与图的 idtype 和设备相同。

    • iterable[int] : 与张量类似,但将节点 ID 存储在序列中(例如,列表、元组、numpy.ndarray)。

    默认情况下,它考虑所有节点。

  • ntype (str, 可选) – 用于查询的节点类型。如果图有多个节点类型,必须指定此参数。否则,可以省略。

返回:

一个一维张量,包含满足谓词的节点的 ID。

返回类型:

Tensor

示例

以下示例使用 PyTorch 后端。

>>> import dgl
>>> import torch

定义一个谓词函数。

>>> def nodes_with_feature_one(nodes):
...     # Whether a node has feature 1
...     return (nodes.data['h'] == 1.).squeeze(1)

过滤同构图中的节点。

>>> g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))
>>> g.ndata['h'] = torch.tensor([[0.], [1.], [1.], [0.]])
>>> print(g.filter_nodes(nodes_with_feature_one))
tensor([1, 2])

过滤 ID 为 0 和 1 的节点

>>> print(g.filter_nodes(nodes_with_feature_one, nodes=torch.tensor([0, 1])))
tensor([1])

过滤异构图中的节点。

>>> g = dgl.heterograph({
...     ('user', 'plays', 'game'): (torch.tensor([0, 1, 1, 2]),
...                                 torch.tensor([0, 0, 1, 1]))})
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [1.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[0.], [1.]])
>>> # Filter for 'user' nodes
>>> print(g.filter_nodes(nodes_with_feature_one, ntype='user'))
tensor([1, 2])