6.4 实现自定义图采样器

实现自定义采样器需要继承 dgl.graphbolt.SubgraphSampler 基类并实现其抽象方法 sample_subgraphs。该 sample_subgraphs 方法应该接收种子节点,这些节点是用于采样邻居的节点

def sample_subgraphs(self, seed_nodes):
    return input_nodes, sampled_subgraphs

该方法应返回输入节点 ID 列表和子图列表。每个子图都是一个 SampledSubgraph 对象。

采样期间所需的任何其他数据(如图结构、扇出大小等)应通过构造函数传递给采样器。

下面的代码实现了一个经典的邻居采样器

@functional_datapipe("customized_sample_neighbor")
class CustomizedNeighborSampler(dgl.graphbolt.SubgraphSampler):
   def __init__(self, datapipe, graph, fanouts):
       super().__init__(datapipe)
       self.graph = graph
       self.fanouts = fanouts

   def sample_subgraphs(self, seed_nodes):
       subgs = []
       for fanout in reversed(self.fanouts):
           # Sample a fixed number of neighbors of the current seed nodes.
           input_nodes, sg = g.sample_neighbors(seed_nodes, fanout)
           subgs.insert(0, sg)
           seed_nodes = input_nodes
       return input_nodes, subgs

要将此采样器与 DataLoader 一起使用

datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.customized_sample_neighbor(g, [10, 10]) # 2 layers.
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)

for data in dataloader:
    input_features = data.node_features["feat"]
    output_labels = data.labels
    output_predictions = model(data.blocks, input_features)
    loss = compute_loss(output_labels, output_predictions)
    opt.zero_grad()
    loss.backward()
    opt.step()

异构图采样器

要为异构图编写采样器,需要注意参数 graph 是一个异构图,而 seeds 可以是一个 ID 张量字典。DGL 的大多数图采样操作符(例如,上面示例中的 sample_neighborsto_block 函数)可以原生支持异构图,因此许多采样器自动适用于异构图。例如,上面的 CustomizedNeighborSampler 可以用于异构图

import dgl.graphbolt as gb
hg = gb.FusedCSCSamplingGraph()
train_set = item_set = gb.HeteroItemSet(
    {
        "user": gb.ItemSet(
            (torch.arange(0, 5), torch.arange(5, 10)),
            names=("seeds", "labels"),
        ),
        "item": gb.ItemSet(
            (torch.arange(5, 10), torch.arange(10, 15)),
            names=("seeds", "labels"),
        ),
    }
)
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.customized_sample_neighbor(g, [10, 10]) # 2 layers.
datapipe = datapipe.fetch_feature(
    feature, node_feature_keys={"user": ["feat"], "item": ["feat"]}
)
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)

for data in dataloader:
    input_features = {
        ntype: data.node_features[(ntype, "feat")]
        for ntype in data.blocks[0].srctypes
    }
    output_labels = data.labels["user"]
    output_predictions = model(data.blocks, input_features)["user"]
    loss = compute_loss(output_labels, output_predictions)
    opt.zero_grad()
    loss.backward()
    opt.step()

采样后排除边

在某些情况下,我们可能希望从采样的子图中排除种子边。例如,在链接预测任务中,我们希望从采样的子图中排除训练集中的边,以防止信息泄露。为此,我们需要在采样后立即添加一个额外的数据管线,如下所示

datapipe = datapipe.customized_sample_neighbor(g, [10, 10]) # 2 layers.
datapipe = datapipe.transform(gb.exclude_seed_edges)

请查看 exclude_seed_edges() 的 API 页面了解更多详情。

上述 API 基于 exclude_edges()。如果您想根据其他标准从采样的子图中排除边,您可以编写自己的转换函数。请查看该方法以作参考。

您还可以参考 链接预测 中的示例。