mask_nodes_by_property

class dgl.data.utils.mask_nodes_by_property(property_values, part_ratios, random_seed=None)[source]

基类

基于给定的节点属性,为具有分布偏移的节点分割提供分割掩码,如 Evaluating Robustness and Uncertainty of Graph Models Under Structural Distributional Shifts 中所提出。

它考虑了节点的分布内 (ID) 和分布外 (OOD) 子集。ID 子集包括训练、验证和测试部分,而 OOD 子集包括验证和测试部分。它按节点属性值的升序排序节点,将它们分成 5 个不相交的部分,并创建 5 个关联的节点掩码数组。

  • ID 节点的 3 个掩码:'in_train_mask', 'in_valid_mask', 'in_test_mask',

  • 以及 OOD 节点的 2 个掩码:'out_valid_mask', 'out_test_mask'

参数:
  • property_values (numpy ndarray) – 用于分割数据集的节点属性(浮点)值。数组长度必须等于图中的节点数。

  • part_ratios (list) – 一个包含 5 个比率的列表,分别用于训练、ID 验证、ID 测试、OOD 验证和 OOD 测试部分。列表中的值总和必须为一。

  • random_seed (int, optional) – 用于固定节点初始置换的随机种子。它用于为具有相同属性值或属于 ID 子集的节点创建随机顺序。(默认值:None)

返回:

split_masks – 一个 Python 字典,以掩码名称为键,相应的节点掩码数组为值。

返回类型:

dict

示例

>>> num_nodes = 1000
>>> property_values = np.random.uniform(size=num_nodes)
>>> part_ratios = [0.3, 0.1, 0.1, 0.3, 0.2]
>>> split_masks = dgl.data.utils.mask_nodes_by_property(property_values, part_ratios)
>>> print('in_valid_mask' in split_masks)
True