第5章:训练图神经网络

(中文版)

概述

本章讨论了如何使用第2章:消息传递中介绍的消息传递方法以及第3章:构建 GNN 模块中介绍的神经网络模块,训练图神经网络用于节点分类、边分类、链接预测和小型图的图分类任务。

本章假设你的图及其所有节点和边特征都能放入 GPU 显存中;如果不能,请参阅第6章:大型图上的随机训练

以下文本假设图以及节点/边特征已经准备好。如果你计划使用 DGL 提供的或其它兼容的 DGLDataset 数据集(如第4章:图数据处理流程中所述),你可以通过类似如下方式获取单图数据集中的图:

import dgl

dataset = dgl.data.CiteseerGraphDataset()
graph = dataset[0]

注意:在本章中,我们将使用 PyTorch 作为后端。

异构图

有时你会想要处理异构图。这里我们以一个合成的异构图为例,演示节点分类、边分类和链接预测任务。

合成异构图 hetero_graph 包含以下边类型:

  • ('user', 'follow', 'user')

  • ('user', 'followed-by', 'user')

  • ('user', 'click', 'item')

  • ('item', 'clicked-by', 'user')

  • ('user', 'dislike', 'item')

  • ('item', 'disliked-by', 'user')

import numpy as np
import torch

n_users = 1000
n_items = 500
n_follows = 3000
n_clicks = 5000
n_dislikes = 500
n_hetero_features = 10
n_user_classes = 5
n_max_clicks = 10

follow_src = np.random.randint(0, n_users, n_follows)
follow_dst = np.random.randint(0, n_users, n_follows)
click_src = np.random.randint(0, n_users, n_clicks)
click_dst = np.random.randint(0, n_items, n_clicks)
dislike_src = np.random.randint(0, n_users, n_dislikes)
dislike_dst = np.random.randint(0, n_items, n_dislikes)

hetero_graph = dgl.heterograph({
    ('user', 'follow', 'user'): (follow_src, follow_dst),
    ('user', 'followed-by', 'user'): (follow_dst, follow_src),
    ('user', 'click', 'item'): (click_src, click_dst),
    ('item', 'clicked-by', 'user'): (click_dst, click_src),
    ('user', 'dislike', 'item'): (dislike_src, dislike_dst),
    ('item', 'disliked-by', 'user'): (dislike_dst, dislike_src)})

hetero_graph.nodes['user'].data['feature'] = torch.randn(n_users, n_hetero_features)
hetero_graph.nodes['item'].data['feature'] = torch.randn(n_items, n_hetero_features)
hetero_graph.nodes['user'].data['label'] = torch.randint(0, n_user_classes, (n_users,))
hetero_graph.edges['click'].data['label'] = torch.randint(1, n_max_clicks, (n_clicks,)).float()
# randomly generate training masks on user nodes and click edges
hetero_graph.nodes['user'].data['train_mask'] = torch.zeros(n_users, dtype=torch.bool).bernoulli(0.6)
hetero_graph.edges['click'].data['train_mask'] = torch.zeros(n_clicks, dtype=torch.bool).bernoulli(0.6)

路线图

本章包含四个小节,每个小节介绍一种图学习任务。