准备数据

在本节中,我们将为前面介绍的 Graphormer 模型准备数据。任何由 DGLGraph 对象组成的数据集都可以配合标准的 PyTorch DataLoader 将数据送入模型。关键在于定义一个 collate 函数:由于同一个 batch 中各个图的节点数不同,需要先把每个图的特征填充(padding)到与 batch 中最大的图相同的尺寸,再将它们组合成一个 batch。collate 函数的示例如下:

def collate(graphs, *, max_len=5, max_degree=512):
    """Pad and batch a list of DGLGraphs into dense Graphormer inputs.

    Every per-graph tensor is padded to ``max_num_nodes``, the node count of
    the largest graph in ``graphs``.

    Args:
        graphs: list of DGLGraph objects. Each graph must provide
            ``ndata["feat"]`` and ``edata["feat"]`` (edge features are
            indexed as 2-D, presumably ``(num_edges, edge_feat_dim)``).
            As a side effect, ``ndata["spd"]`` and ``ndata["path"]`` are
            written on every graph.
        max_len: number of edge IDs kept per shortest path; longer paths
            are truncated, shorter ones are padded with edge ID -1.
        max_degree: upper bound that the shifted (+1) in/out degrees are
            clamped to.

    Returns:
        A tuple ``(node_feat, in_degree, out_degree, attn_mask, path_data,
        dist)``:

        * node_feat: node features + 1, zero-padded,
          ``(num_graphs, max_num_nodes, *feat_shape)``.
        * in_degree / out_degree: degrees + 1 clamped to ``max_degree``,
          zero-padded, ``(num_graphs, max_num_nodes)``.
        * attn_mask: float tensor
          ``(num_graphs, max_num_nodes, max_num_nodes)``; 1 marks key
          positions (padded nodes) that must not be attended.
        * path_data: edge features (+1) along each shortest path, with
          all-zero rows for padded edges,
          ``(num_graphs, max_num_nodes, max_num_nodes, max_len, edge_feat_dim)``.
        * dist: shortest-path distances, -1 for unreachable and padded
          pairs, ``(num_graphs, max_num_nodes, max_num_nodes)``.

    Raises:
        ValueError: if ``graphs`` is empty.
    """
    if not graphs:
        raise ValueError("collate() requires at least one graph")

    # Compute shortest-path features; this can also be done in advance.
    for g in graphs:
        spd, path = dgl.shortest_dist(g, root=None, return_paths=True)
        g.ndata["spd"] = spd
        g.ndata["path"] = path

    num_graphs = len(graphs)
    num_nodes = [g.num_nodes() for g in graphs]
    max_num_nodes = max(num_nodes)

    attn_mask = th.zeros(num_graphs, max_num_nodes, max_num_nodes)
    # shortest_dist returns -1 for unreachable node pairs, and padded nodes
    # are unreachable from everything, so distances involving padded nodes
    # use -1 padding as well.
    dist = th.full(
        (num_graphs, max_num_nodes, max_num_nodes), -1, dtype=th.long
    )
    node_feat, in_degree, out_degree, path_data = [], [], [], []

    for i, (g, n) in enumerate(zip(graphs, num_nodes)):
        # Mask every padded key column (index >= n) so real nodes never
        # attend to padding. Only an empty graph (n == 0) would have every
        # column masked, which yields NaN after softmax, so column 0 is kept
        # attendable in that single case.
        attn_mask[i, :, max(n, 1):] = 1

        # +1 to distinguish padded non-existing nodes from real nodes.
        node_feat.append(g.ndata["feat"] + 1)

        # +1 so that degree 0 is reserved for padding.
        in_degree.append(
            th.clamp(g.in_degrees() + 1, min=0, max=max_degree)
        )
        out_degree.append(
            th.clamp(g.out_degrees() + 1, min=0, max=max_degree)
        )

        # Bring every path to exactly max_len edge IDs (truncate, or pad
        # with -1 to match shortest_dist's own padding), then pad both node
        # dimensions up to max_num_nodes. Result shape:
        # [max_num_nodes, max_num_nodes, max_len].
        path = g.ndata["path"][:, :, :max_len]
        pad_num_nodes = max_num_nodes - n
        shortest_path = th.nn.functional.pad(
            path,
            (0, max_len - path.size(2), 0, pad_num_nodes, 0, pad_num_nodes),
            "constant",
            -1,
        )

        # +1 to distinguish padded non-existing edges from real edges.
        edata = g.edata["feat"] + 1
        # Append one all-zero row: edge ID -1 (padding) indexes this last
        # row, so padded path positions get all-zero features.
        edata = th.cat(
            (edata, th.zeros(1, edata.shape[1], device=edata.device)), dim=0
        )
        path_data.append(edata[shortest_path])

        dist[i, :n, :n] = g.ndata["spd"]

    # Zero-pad per-graph node tensors to max_num_nodes.
    node_feat = th.nn.utils.rnn.pad_sequence(node_feat, batch_first=True)
    in_degree = th.nn.utils.rnn.pad_sequence(in_degree, batch_first=True)
    out_degree = th.nn.utils.rnn.pad_sequence(out_degree, batch_first=True)

    return (
        node_feat,
        in_degree,
        out_degree,
        attn_mask,
        th.stack(path_data),
        dist,
    )

在此示例中,我们省略了一些细节,例如添加虚拟节点。更多详情请参考 Graphormer 示例。