DGL

入门

  • 安装与设置
  • DGL 速览

进阶内容

  • 🆕 使用 GraphBolt 对 GNN 进行随机训练
  • 用户指南
    • 第 1 章:图
    • 第 2 章:消息传递
    • 第 3 章:构建 GNN 模块
    • 第 4 章:图数据管道
    • 第 5 章:训练图神经网络
    • 第 6 章:在大规模图上进行随机训练
      • 6.1 使用邻居采样训练 GNN 进行节点分类
      • 6.2 使用邻居采样训练 GNN 进行边分类
      • 6.3 使用邻居采样训练 GNN 进行链接预测
      • 6.4 实现自定义图采样器
      • 6.5 使用 DGL sparse 训练 GNN
      • 6.6 实现用于 Mini-batch 训练的自定义 GNN 模块
      • 6.7 在大规模图上进行精确离线推理
      • 6.8 使用 GPU 进行邻居采样
      • 6.9 数据加载并行性
    • 第 7 章:分布式训练
    • 第 8 章:混合精度训练
  • 用户指南【包含过时信息】
  • 사용자 가이드[시대에 뒤쳐진]
  • 🆕 教程:图 Transformer
  • 教程:dgl.sparse
  • 在 CPU 上训练
  • 在多个 GPU 上训练
  • 分布式训练
  • 使用 DGL 进行论文研习

API 参考

  • dgl
  • dgl.data
  • dgl.dataloading
  • dgl.DGLGraph
  • dgl.distributed
  • dgl.function
  • dgl.geometry
  • 🆕 dgl.graphbolt
  • dgl.nn (PyTorch)
  • dgl.nn.functional
  • dgl.ops
  • dgl.optim
  • dgl.sampling
  • dgl.sparse
  • dgl.multiprocessing
  • dgl.transforms
  • 用户自定义函数

说明

  • 贡献 DGL
  • DGL 外部函数接口 (FFI)
  • 性能基准测试

杂项

  • 常见问题 (FAQ)
  • 环境变量
  • 资源
DGL
  • 用户指南
  • 第 6 章:在大规模图上进行随机训练
  • 6.3 使用邻居采样训练 GNN 进行链接预测
  • 查看页面源码

6.3 使用邻居采样训练 GNN 进行链接预测

(中文版)

定义一个包含邻居采样和负采样的数据加载器

您仍然可以使用与节点/边分类相同的数据加载器。唯一的区别是,您需要在邻居采样阶段之前添加一个额外的 负采样 阶段。以下数据加载器将为边的每个源节点均匀选择 5 个负目标节点。

datapipe = datapipe.sample_uniform_negative(graph, 5)

完整的数据加载器流程如下

datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_uniform_negative(graph, 5)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
datapipe = datapipe.transform(gb.exclude_seed_edges)
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)

有关内置均匀负采样器的详细信息,请参阅 UniformNegativeSampler。

您也可以提供自己的负采样函数,只要它继承自 NegativeSampler 并覆盖 _sample_with_etype() 方法即可。该方法接收 minibatch 中的节点对,并返回负节点对。

以下提供了一个自定义负采样器的示例,该采样器根据与节点度数的幂成比例的概率分布来采样负目标节点。

@functional_datapipe("customized_sample_negative")
class CustomizedNegativeSampler(dgl.graphbolt.NegativeSampler):
    def __init__(self, datapipe, k, node_degrees):
        super().__init__(datapipe, k)
        # caches the probability distribution
        self.weights = node_degrees ** 0.75
        self.k = k

    def _sample_with_etype(self, seeds, etype=None):
        src, _ = seeds.T
        src = src.repeat_interleave(self.k)
        dst = self.weights.multinomial(len(src), replacement=True)
        return src, dst

datapipe = datapipe.customized_sample_negative(5, node_degrees)

定义一个用于 minibatch 训练的 GraphSAGE 模型

class SAGE(nn.Module):
    def __init__(self, in_size, hidden_size):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
        self.hidden_size = hidden_size
        self.predictor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, blocks, x):
        hidden_x = x
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            is_last_layer = layer_idx == len(self.layers) - 1
            if not is_last_layer:
                hidden_x = F.relu(hidden_x)
        return hidden_x

当提供了负采样器时,数据加载器除了 消息流图 (MFGs) 外,还会为每个 minibatch 生成正负节点对。使用 compacted_seeds 和 labels 来获取紧凑的节点对和相应的标签。

训练循环

训练循环仅涉及迭代数据加载器,并将图和输入特征馈送给上面定义的模型。

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in tqdm.trange(args.epochs):
    model.train()
    total_loss = 0
    start_epoch_time = time.time()
    for step, data in enumerate(dataloader):
        # Unpack MiniBatch.
        compacted_seeds = data.compacted_seeds.T
        labels = data.labels
        node_feature = data.node_features["feat"]
        # Convert sampled subgraphs to DGL blocks.
        blocks = data.blocks

        # Get the embeddings of the input nodes.
        y = model(blocks, node_feature)
        logits = model.predictor(
            y[compacted_seeds[0]] * y[compacted_seeds[1]]
        ).squeeze()

        # Compute loss.
        loss = F.binary_cross_entropy_with_logits(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    end_epoch_time = time.time()

DGL 提供了 无监督学习 GraphSAGE 的示例,展示了在同构图上进行链接预测的方法。

对于异构图

之前的模型可以很容易地扩展到异构图。唯一的区别是您需要根据边类型使用 HeteroGraphConv 来包装 SAGEConv。

class SAGE(nn.Module):
    def __init__(self, in_size, hidden_size):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.HeteroGraphConv({
                rel : dglnn.SAGEConv(in_size, hidden_size, "mean")
                for rel in rel_names
            }))
        self.layers.append(dglnn.HeteroGraphConv({
                rel : dglnn.SAGEConv(hidden_size, hidden_size, "mean")
                for rel in rel_names
            }))
        self.layers.append(dglnn.HeteroGraphConv({
                rel : dglnn.SAGEConv(hidden_size, hidden_size, "mean")
                for rel in rel_names
            }))
        self.hidden_size = hidden_size
        self.predictor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, blocks, x):
        hidden_x = x
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            is_last_layer = layer_idx == len(self.layers) - 1
            if not is_last_layer:
                hidden_x = F.relu(hidden_x)
        return hidden_x

数据加载器的定义也与同构图非常相似。唯一的区别是您需要提供边类型来进行特征获取。

datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_uniform_negative(graph, 5)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
datapipe = datapipe.transform(gb.exclude_seed_edges)
datapipe = datapipe.fetch_feature(
    feature,
    node_feature_keys={"user": ["feat"], "item": ["feat"]}
)
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)

如果您想提供自己的负采样函数,只需继承自 NegativeSampler 类并覆盖 _sample_with_etype() 方法即可。

@functional_datapipe("customized_sample_negative")
class CustomizedNegativeSampler(dgl.graphbolt.NegativeSampler):
    def __init__(self, datapipe, k, node_degrees):
        super().__init__(datapipe, k)
        # caches the probability distribution
        self.weights = {
            etype: node_degrees[etype] ** 0.75 for etype in node_degrees
        }
        self.k = k

    def _sample_with_etype(self, seeds, etype):
        src, _ = seeds.T
        src = src.repeat_interleave(self.k)
        dst = self.weights[etype].multinomial(len(src), replacement=True)
        return src, dst

datapipe = datapipe.customized_sample_negative(5, node_degrees)

对于异构图,节点对按边类型分组。训练循环再次与同构图上的训练循环几乎相同,除了在特定边类型上计算损失。

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

category = "user"
for epoch in tqdm.trange(args.epochs):
    model.train()
    total_loss = 0
    start_epoch_time = time.time()
    for step, data in enumerate(dataloader):
        # Unpack MiniBatch.
        compacted_seeds = data.compacted_seeds
        labels = data.labels
        node_features = {
            ntype: data.node_features[(ntype, "feat")]
            for ntype in data.blocks[0].srctypes
        }
        # Convert sampled subgraphs to DGL blocks.
        blocks = data.blocks
        # Get the embeddings of the input nodes.
        y = model(blocks, node_feature)
        logits = model.predictor(
            y[category][compacted_pairs[category][:, 0]]
            * y[category][compacted_pairs[category][:, 1]]
        ).squeeze()

        # Compute loss.
        loss = F.binary_cross_entropy_with_logits(logits, labels[category])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    end_epoch_time = time.time()
上一页 下一页

© 版权所有 2018, DGL 团队。

使用 Sphinx 构建,并使用了 Read the Docs 提供的 主题。
Read the Docs v: latest
版本
下载
在 Read the Docs 上
项目主页
构建