mask_nodes_by_property
- class dgl.data.utils.mask_nodes_by_property(property_values, part_ratios, random_seed=None)[source]
基类
基于给定的节点属性,为具有分布偏移的节点分割提供分割掩码,如 Evaluating Robustness and Uncertainty of Graph Models Under Structural Distributional Shifts 中所提出。
它考虑了节点的分布内 (ID) 和分布外 (OOD) 子集。ID 子集包括训练、验证和测试部分,而 OOD 子集包括验证和测试部分。它按节点属性值的升序排序节点,将它们分成 5 个不相交的部分,并创建 5 个关联的节点掩码数组。
ID 节点的 3 个掩码:
'in_train_mask'
,'in_valid_mask'
,'in_test_mask'
,以及 OOD 节点的 2 个掩码:
'out_valid_mask'
,'out_test_mask'
。
- 参数:
- 返回:
split_masks – 一个 Python 字典,以掩码名称为键,相应的节点掩码数组为值。
- 返回类型:
示例
>>> num_nodes = 1000 >>> property_values = np.random.uniform(size=num_nodes) >>> part_ratios = [0.3, 0.1, 0.1, 0.3, 0.2] >>> split_masks = dgl.data.utils.mask_nodes_by_property(property_values, part_ratios) >>> print('in_valid_mask' in split_masks) True