dgl.DGLGraph.filter_nodes
- DGLGraph.filter_nodes(predicate, nodes='__ALL__', ntype=None)[源码]
返回满足给定谓词的、具有给定节点类型的节点的 ID。
- 参数:
predicate (callable) – 签名
func(nodes) -> Tensor
的函数。nodes
是dgl.NodeBatch
对象。其输出张量应为一维布尔张量,每个元素指示批量中相应节点是否满足谓词。nodes (节点 ID(s),可选) –
用于查询的节点。允许的格式包括
Tensor: 一个包含用于查询的节点的一维张量,其数据类型和设备应与图的
idtype
和设备相同。iterable[int] : 与张量类似,但将节点 ID 存储在序列中(例如,列表、元组、numpy.ndarray)。
默认情况下,它考虑所有节点。
ntype (str, 可选) – 用于查询的节点类型。如果图有多个节点类型,必须指定此参数。否则,可以省略。
- 返回:
一个一维张量,包含满足谓词的节点的 ID。
- 返回类型:
Tensor
示例
以下示例使用 PyTorch 后端。
>>> import dgl >>> import torch
定义一个谓词函数。
>>> def nodes_with_feature_one(nodes): ... # Whether a node has feature 1 ... return (nodes.data['h'] == 1.).squeeze(1)
过滤同构图中的节点。
>>> g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3]))) >>> g.ndata['h'] = torch.tensor([[0.], [1.], [1.], [0.]]) >>> print(g.filter_nodes(nodes_with_feature_one)) tensor([1, 2])
过滤 ID 为 0 和 1 的节点
>>> print(g.filter_nodes(nodes_with_feature_one, nodes=torch.tensor([0, 1]))) tensor([1])
过滤异构图中的节点。
>>> g = dgl.heterograph({ ... ('user', 'plays', 'game'): (torch.tensor([0, 1, 1, 2]), ... torch.tensor([0, 0, 1, 1]))}) >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [1.]]) >>> g.nodes['game'].data['h'] = torch.tensor([[0.], [1.]]) >>> # Filter for 'user' nodes >>> print(g.filter_nodes(nodes_with_feature_one, ntype='user')) tensor([1, 2])