6.2 使用邻居采样训练用于边分类的 GNN
用于边分类/回归的训练与节点分类/回归的训练有些相似,但有几个显著差异。
定义邻居采样器和数据加载器
您可以使用与节点分类相同的邻居采样器。
datapipe = datapipe.sample_neighbor(g, [10, 10])
# Or equivalently
datapipe = dgl.graphbolt.NeighborSampler(datapipe, g, [10, 10])
定义数据加载器的代码也与节点分类相同。唯一的区别是它迭代的是训练集中的边(即节点对),而不是节点。
import dgl.graphbolt as gb
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
g = gb.SamplingGraph()
seeds = torch.arange(0, 1000).reshape(-1, 2)
labels = torch.randint(0, 2, (5,))
train_set = gb.ItemSet((seeds, labels), names=("seeds", "labels"))
datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
# Or equivalently:
# datapipe = gb.NeighborSampler(datapipe, g, [10, 10])
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)
迭代数据加载器将产生 MiniBatch
,其中包含一个特殊创建的图列表,代表每层上的计算依赖关系。您可以通过 mini_batch.blocks 访问 消息流图 (MFGs)。
注意
关于消息流图的概念,请参阅:doc:`随机训练教程 <../notebooks/stochastic_training/neighbor_sampling_overview.nblink>`__。
如果您想开发自己的邻居采样器或需要更详细地解释 MFG 的概念,请参阅6.4 实现自定义图采样器。
从原始图中移除 minibatch 中的边进行邻居采样
训练边分类模型时,有时您希望从计算依赖关系中移除出现在训练数据中的边,就像它们从未存在过一样。否则,模型将“知道”两个节点之间存在一条边的事实,并可能利用它获得优势。
因此,在边分类中,您有时会希望从采样的 minibatch 中排除种子边及其反向边。您可以使用 exclude_seed_edges()
和 MiniBatchTransformer
来实现这一点。
import dgl.graphbolt as gb
from functools import partial
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
g = gb.SamplingGraph()
seeds = torch.arange(0, 1000).reshape(-1, 2)
labels = torch.randint(0, 2, (5,))
train_set = gb.ItemSet((seeds, labels), names=("seeds", "labels"))
datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
exclude_seed_edges = partial(gb.exclude_seed_edges, include_reverse_edges=True)
datapipe = datapipe.transform(exclude_seed_edges)
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)
使您的模型适应 minibatch 训练
边分类模型通常包含两部分
一部分用于获取关联节点的表示。
另一部分用于从关联节点表示计算边得分。
前一部分与节点分类的模型完全相同,我们可以直接重用它。输入仍然是由 DGL 提供的数据加载器生成的 MFG 列表,以及输入特征。
class StochasticTwoLayerGCN(nn.Module):
def __init__(self, in_features, hidden_features, out_features):
super().__init__()
self.conv1 = dglnn.GraphConv(in_features, hidden_features)
self.conv2 = dglnn.GraphConv(hidden_features, out_features)
def forward(self, blocks, x):
x = F.relu(self.conv1(blocks[0], x))
x = F.relu(self.conv2(blocks[1], x))
return x
后一部分的输入通常是前一部分的输出,以及由 minibatch 中的边诱导的原始图的子图(节点对)。该子图是由同一数据加载器生成的。
以下代码展示了一个示例,通过连接关联节点特征并使用密集层进行投影来预测边的得分。
class ScorePredictor(nn.Module):
def __init__(self, num_classes, in_features):
super().__init__()
self.W = nn.Linear(2 * in_features, num_classes)
def forward(self, seeds, x):
src_x = x[seeds[:, 0]]
dst_x = x[seeds[:, 1]]
data = torch.cat([src_x, dst_x], 1)
return self.W(data)
整个模型将接受数据加载器生成的 MFG 列表和边,以及如下所示的输入节点特征
class Model(nn.Module):
def __init__(self, in_features, hidden_features, out_features, num_classes):
super().__init__()
self.gcn = StochasticTwoLayerGCN(
in_features, hidden_features, out_features)
self.predictor = ScorePredictor(num_classes, out_features)
def forward(self, blocks, x, seeds):
x = self.gcn(blocks, x)
return self.predictor(seeds, x)
DGL 确保边子图中的节点与生成的 MFG 列表中的最后一个 MFG 的输出节点相同。
训练循环
训练循环与节点分类非常相似。您可以迭代数据加载器并获取由 minibatch 中的边诱导的子图,以及计算其关联节点表示所需的 MFG 列表。
import torch.nn.functional as F
model = Model(in_features, hidden_features, out_features, num_classes)
model = model.to(device)
opt = torch.optim.Adam(model.parameters())
for data in dataloader:
blocks = data.blocks
x = data.edge_features("feat")
y_hat = model(data.blocks, x, data.compacted_seeds)
loss = F.cross_entropy(data.labels, y_hat)
opt.zero_grad()
loss.backward()
opt.step()
对于异构图
在异构图上计算节点表示的模型也可用于计算边分类/回归的关联节点表示。
class StochasticTwoLayerRGCN(nn.Module):
def __init__(self, in_feat, hidden_feat, out_feat, rel_names):
super().__init__()
self.conv1 = dglnn.HeteroGraphConv({
rel : dglnn.GraphConv(in_feat, hidden_feat, norm='right')
for rel in rel_names
})
self.conv2 = dglnn.HeteroGraphConv({
rel : dglnn.GraphConv(hidden_feat, out_feat, norm='right')
for rel in rel_names
})
def forward(self, blocks, x):
x = self.conv1(blocks[0], x)
x = self.conv2(blocks[1], x)
return x
对于得分预测,同构图和异构图之间唯一的实现差异在于我们是遍历边类型。
class ScorePredictor(nn.Module):
def __init__(self, num_classes, in_features):
super().__init__()
self.W = nn.Linear(2 * in_features, num_classes)
def forward(self, seeds, x):
scores = {}
for etype in seeds.keys():
src, dst = seeds[etype].T
data = torch.cat([x[etype][src], x[etype][dst]], 1)
scores[etype] = self.W(data)
return scores
class Model(nn.Module):
def __init__(self, in_features, hidden_features, out_features, num_classes,
etypes):
super().__init__()
self.rgcn = StochasticTwoLayerRGCN(
in_features, hidden_features, out_features, etypes)
self.pred = ScorePredictor(num_classes, out_features)
def forward(self, seeds, blocks, x):
x = self.rgcn(blocks, x)
return self.pred(seeds, x)
数据加载器的定义与同构图几乎相同。唯一的区别是 train_set 现在是 HeteroItemSet
的实例,而不是 ItemSet
的实例。
import dgl.graphbolt as gb
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
g = gb.SamplingGraph()
seeds = torch.arange(0, 1000).reshape(-1, 2)
labels = torch.randint(0, 3, (1000,))
seeds_labels = {
"user:like:item": gb.ItemSet(
(seeds, labels), names=("seeds", "labels")
),
"user:follow:user": gb.ItemSet(
(seeds, labels), names=("seeds", "labels")
),
}
train_set = gb.HeteroItemSet(seeds_labels)
datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
datapipe = datapipe.fetch_feature(
feature, node_feature_keys={"item": ["feat"], "user": ["feat"]}
)
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)
如果您希望在异构图上排除反向边,情况会有些不同。在异构图上,反向边通常与边本身具有不同的边类型,以区分“正向”和“反向”关系(例如,follow
和 followed_by
互为反向关系,like
和 liked_by
互为反向关系等)。
如果某种类型的每条边在另一种类型中都有一个具有相同 ID 的反向边,您可以指定边类型与其反向类型之间的映射。排除 minibatch 中的边及其反向边的方法如下。
exclude_seed_edges = partial(
gb.exclude_seed_edges,
include_reverse_edges=True,
reverse_etypes_mapping={
"user:like:item": "item:liked_by:user",
"user:follow:user": "user:followed_by:user",
},
)
datapipe = datapipe.transform(exclude_seed_edges)
训练循环再次与同构图上的训练循环几乎相同,除了 compute_loss
的实现,此处它将接受两个包含节点类型和预测的字典。
import torch.nn.functional as F
model = Model(in_features, hidden_features, out_features, num_classes, etypes)
model = model.to(device)
opt = torch.optim.Adam(model.parameters())
for data in dataloader:
blocks = data.blocks
x = data.edge_features(("user:like:item", "feat"))
y_hat = model(data.blocks, x, data.compacted_seeds)
loss = F.cross_entropy(data.labels, y_hat)
opt.zero_grad()
loss.backward()
opt.step()