6.4 实现自定义图采样器
实现自定义采样器需要继承 dgl.graphbolt.SubgraphSampler
基类并实现其抽象方法 sample_subgraphs
。该 sample_subgraphs
方法应该接收种子节点,这些节点是用于采样邻居的节点
def sample_subgraphs(self, seed_nodes):
return input_nodes, sampled_subgraphs
该方法应返回输入节点 ID 列表和子图列表。每个子图都是一个 SampledSubgraph
对象。
采样期间所需的任何其他数据(如图结构、扇出大小等)应通过构造函数传递给采样器。
下面的代码实现了一个经典的邻居采样器
@functional_datapipe("customized_sample_neighbor")
class CustomizedNeighborSampler(dgl.graphbolt.SubgraphSampler):
def __init__(self, datapipe, graph, fanouts):
super().__init__(datapipe)
self.graph = graph
self.fanouts = fanouts
def sample_subgraphs(self, seed_nodes):
subgs = []
for fanout in reversed(self.fanouts):
# Sample a fixed number of neighbors of the current seed nodes.
input_nodes, sg = g.sample_neighbors(seed_nodes, fanout)
subgs.insert(0, sg)
seed_nodes = input_nodes
return input_nodes, subgs
要将此采样器与 DataLoader
一起使用
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.customized_sample_neighbor(g, [10, 10]) # 2 layers.
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)
for data in dataloader:
input_features = data.node_features["feat"]
output_labels = data.labels
output_predictions = model(data.blocks, input_features)
loss = compute_loss(output_labels, output_predictions)
opt.zero_grad()
loss.backward()
opt.step()
异构图采样器
要为异构图编写采样器,需要注意参数 graph 是一个异构图,而 seeds 可以是一个 ID 张量字典。DGL 的大多数图采样操作符(例如,上面示例中的 sample_neighbors
和 to_block
函数)可以原生支持异构图,因此许多采样器自动适用于异构图。例如,上面的 CustomizedNeighborSampler
可以用于异构图
import dgl.graphbolt as gb
hg = gb.FusedCSCSamplingGraph()
train_set = item_set = gb.HeteroItemSet(
{
"user": gb.ItemSet(
(torch.arange(0, 5), torch.arange(5, 10)),
names=("seeds", "labels"),
),
"item": gb.ItemSet(
(torch.arange(5, 10), torch.arange(10, 15)),
names=("seeds", "labels"),
),
}
)
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.customized_sample_neighbor(g, [10, 10]) # 2 layers.
datapipe = datapipe.fetch_feature(
feature, node_feature_keys={"user": ["feat"], "item": ["feat"]}
)
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)
for data in dataloader:
input_features = {
ntype: data.node_features[(ntype, "feat")]
for ntype in data.blocks[0].srctypes
}
output_labels = data.labels["user"]
output_predictions = model(data.blocks, input_features)["user"]
loss = compute_loss(output_labels, output_predictions)
opt.zero_grad()
loss.backward()
opt.step()
采样后排除边
在某些情况下,我们可能希望从采样的子图中排除种子边。例如,在链接预测任务中,我们希望从采样的子图中排除训练集中的边,以防止信息泄露。为此,我们需要在采样后立即添加一个额外的数据管线,如下所示
datapipe = datapipe.customized_sample_neighbor(g, [10, 10]) # 2 layers.
datapipe = datapipe.transform(gb.exclude_seed_edges)
请查看 exclude_seed_edges()
的 API 页面了解更多详情。
上述 API 基于 exclude_edges()
。如果您想根据其他标准从采样的子图中排除边,您可以编写自己的转换函数。请查看该方法以作参考。
您还可以参考 链接预测 中的示例。