Prepare Data
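In this section, we prepare the data for the Graphormer model introduced earlier. Any dataset containing DGLGraph objects can be fed to the model through a standard PyTorch dataloader. The key is to define a collate function that groups the features of multiple graphs into a batch. An example collate function is shown below: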
import dgl
import torch as th


def collate(graphs):
    # Compute shortest-path features; this can also be done in advance.
    for g in graphs:
        spd, path = dgl.shortest_dist(g, root=None, return_paths=True)
        g.ndata["spd"] = spd
        g.ndata["path"] = path

    num_graphs = len(graphs)
    num_nodes = [g.num_nodes() for g in graphs]
    max_num_nodes = max(num_nodes)

    attn_mask = th.zeros(num_graphs, max_num_nodes, max_num_nodes)
    node_feat = []
    in_degree, out_degree = [], []
    path_data = []
    # Since shortest_dist returns -1 for unreachable node pairs and padded
    # nodes are unreachable from all others, distances involving padded
    # nodes use -1 padding as well.
    dist = -th.ones(
        (num_graphs, max_num_nodes, max_num_nodes), dtype=th.long
    )

    for i in range(num_graphs):
        # A binary mask where invalid positions are indicated by 1 (True).
        # Only the padded columns are masked, so no row ends up with all
        # positions invalid. No virtual node is added in this example, so
        # padding starts at index num_nodes[i].
        attn_mask[i, :, num_nodes[i] :] = 1

        # +1 to distinguish padded non-existing nodes from real nodes
        node_feat.append(graphs[i].ndata["feat"] + 1)

        # 0 for padding
        in_degree.append(
            th.clamp(graphs[i].in_degrees() + 1, min=0, max=512)
        )
        out_degree.append(
            th.clamp(graphs[i].out_degrees() + 1, min=0, max=512)
        )

        # Pad or truncate all paths to the same length "max_len".
        path = graphs[i].ndata["path"]
        path_len = path.size(dim=2)
        # shape of shortest_path: [n, n, max_len]
        max_len = 5
        if path_len >= max_len:
            shortest_path = path[:, :, :max_len]
        else:
            p1d = (0, max_len - path_len)
            # Use the same -1 padding as shortest_dist for
            # invalid edge IDs.
            shortest_path = th.nn.functional.pad(path, p1d, "constant", -1)
        # Pad both node dimensions up to max_num_nodes.
        pad_num_nodes = max_num_nodes - num_nodes[i]
        p3d = (0, 0, 0, pad_num_nodes, 0, pad_num_nodes)
        shortest_path = th.nn.functional.pad(shortest_path, p3d, "constant", -1)
        # +1 to distinguish padded non-existing edges from real edges
        edata = graphs[i].edata["feat"] + 1

        # shortest_dist pads non-existing edges (at the end of shortest
        # paths) with edge ID -1, and th.zeros(1, edata.shape[1]) stands
        # for all padded edge features: index -1 selects this last row.
        edata = th.cat(
            (edata, th.zeros(1, edata.shape[1]).to(edata.device)), dim=0
        )
        path_data.append(edata[shortest_path])

        dist[i, : num_nodes[i], : num_nodes[i]] = graphs[i].ndata["spd"]

    # node feat padding
    node_feat = th.nn.utils.rnn.pad_sequence(node_feat, batch_first=True)

    # degree padding
    in_degree = th.nn.utils.rnn.pad_sequence(in_degree, batch_first=True)
    out_degree = th.nn.utils.rnn.pad_sequence(out_degree, batch_first=True)

    return (
        node_feat,
        in_degree,
        out_degree,
        attn_mask,
        th.stack(path_data),
        dist,
    )
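The path-padding step in the collate function can be hard to follow, so here is a minimal sketch of it in isolation, using only PyTorch. The toy `path` tensor, the edge features, and the values of `max_len` and `max_num_nodes` are made up for illustration and do not come from `dgl.shortest_dist`. The sketch shows how -1 padding combined with an appended all-zero row makes every padded path step pick up a zero feature vector.

```python
import torch as th
import torch.nn.functional as F

max_len = 3        # assumed cap on path length (the collate function uses 5)
max_num_nodes = 3  # assumed largest graph size in the batch

# Toy path tensor for a 2-node graph, shape [n, n, path_len] = [2, 2, 2].
# Entries are edge IDs; -1 means "no edge at this step".
path = th.tensor([[[-1, -1], [0, 2]], [[1, -1], [-1, -1]]])
n, _, path_len = path.shape

# Pad the last dimension up to max_len with -1.
shortest_path = F.pad(path, (0, max_len - path_len), "constant", -1)
# Pad both node dimensions up to max_num_nodes with -1.
pad_num_nodes = max_num_nodes - n
shortest_path = F.pad(
    shortest_path, (0, 0, 0, pad_num_nodes, 0, pad_num_nodes), "constant", -1
)

# Toy edge features: 3 edges, 2 categorical columns. Shift by 1 so that
# 0 is reserved for padding.
edata = th.tensor([[0, 1], [2, 0], [1, 1]]) + 1
# Append one all-zero row; index -1 now selects this padding row.
edata = th.cat((edata, th.zeros(1, edata.shape[1], dtype=edata.dtype)), dim=0)

# Advanced indexing gathers one feature row per path step.
path_feat = edata[shortest_path]  # shape [3, 3, 3, 2]
```

Real path steps keep their shifted edge features, while every padded step and every padded node pair maps to the zero row, which a downstream embedding can treat as padding.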
This example omits some details, such as adding a virtual node. For more details, please refer to the Graphormer example.
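To see how a collate function plugs into a standard PyTorch dataloader, the following sketch may help. It runs without DGL: `samples` and `toy_collate` are hypothetical stand-ins, with plain node-feature tensors in place of DGLGraph objects. Only the node-feature padding and the attention mask are reproduced, and the mask covers exactly the padded columns because no virtual node is added.

```python
import torch as th
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Hypothetical dataset: each item is a [num_nodes, feat_dim] tensor of
# categorical node features, standing in for a graph.
samples = [
    th.randint(0, 5, (3, 2)),
    th.randint(0, 5, (5, 2)),
    th.randint(0, 5, (2, 2)),
    th.randint(0, 5, (4, 2)),
]


def toy_collate(batch):
    num_nodes = [x.shape[0] for x in batch]
    # +1 so that 0 can be used as the padding value.
    feats = [x + 1 for x in batch]
    # Pad every sample to the largest node count in this batch.
    node_feat = pad_sequence(feats, batch_first=True)
    max_num_nodes = node_feat.shape[1]
    # 1 marks invalid (padded) key positions; real columns stay 0.
    attn_mask = th.zeros(len(batch), max_num_nodes, max_num_nodes)
    for i, n in enumerate(num_nodes):
        attn_mask[i, :, n:] = 1
    return node_feat, attn_mask


loader = DataLoader(samples, batch_size=2, shuffle=False, collate_fn=toy_collate)
batches = list(loader)
for node_feat, attn_mask in batches:
    print(node_feat.shape, attn_mask.shape)
```

With real data, the dataset would hold DGLGraph objects and `collate_fn` would be the `collate` function from this section. The padded sizes are computed per batch, so different batches can have different shapes.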