构建自己的数据集

本教程假设您已经了解使用 GNN 进行节点分类训练的基础知识以及如何创建、加载和存储 DGL 图

通过本教程的学习,您将能够

  • 创建用于节点分类、链接预测或图分类的自定义图数据集。

(预计时间:15 分钟)

DGLDataset 对象概述

您的自定义图数据集应该继承 dgl.data.DGLDataset 类并实现以下方法

  • __getitem__(self, i):检索数据集的第 i 个示例。一个示例通常包含一个 DGL 图,偶尔也包含其标签。

  • __len__(self):数据集中示例的数量。

  • process(self):从磁盘加载和处理原始数据。

从 CSV 创建用于图分类的数据集

创建图分类数据集涉及实现 __getitem__ 以返回图及其图级标签。

本教程演示了如何使用以下合成 CSV 数据创建图分类数据集

  • graph_edges.csv:包含三列

    • graph_id:图的 ID。

    • src:给定图的一条边的源节点。

    • dst:给定图的一条边的目标节点。

  • graph_properties.csv:包含三列

    • graph_id:图的 ID。

    • label:图的标签。

    • num_nodes:图中的节点数量。

urllib.request.urlretrieve(
    "https://data.dgl.ai/tutorial/dataset/graph_edges.csv", "./graph_edges.csv"
)
urllib.request.urlretrieve(
    "https://data.dgl.ai/tutorial/dataset/graph_properties.csv",
    "./graph_properties.csv",
)
edges = pd.read_csv("./graph_edges.csv")
properties = pd.read_csv("./graph_properties.csv")

edges.head()

properties.head()


class SyntheticDataset(DGLDataset):
    def __init__(self):
        super().__init__(name="synthetic")

    def process(self):
        edges = pd.read_csv("./graph_edges.csv")
        properties = pd.read_csv("./graph_properties.csv")
        self.graphs = []
        self.labels = []

        # Create a graph for each graph ID from the edges table.
        # First process the properties table into two dictionaries with graph IDs as keys.
        # The label and number of nodes are values.
        label_dict = {}
        num_nodes_dict = {}
        for _, row in properties.iterrows():
            label_dict[row["graph_id"]] = row["label"]
            num_nodes_dict[row["graph_id"]] = row["num_nodes"]

        # For the edges, first group the table by graph IDs.
        edges_group = edges.groupby("graph_id")

        # For each graph ID...
        for graph_id in edges_group.groups:
            # Find the edges as well as the number of nodes and its label.
            edges_of_id = edges_group.get_group(graph_id)
            src = edges_of_id["src"].to_numpy()
            dst = edges_of_id["dst"].to_numpy()
            num_nodes = num_nodes_dict[graph_id]
            label = label_dict[graph_id]

            # Create a graph and add it to the list of graphs and labels.
            g = dgl.graph((src, dst), num_nodes=num_nodes)
            self.graphs.append(g)
            self.labels.append(label)

        # Convert the label list to tensor for saving.
        self.labels = torch.LongTensor(self.labels)

    def __getitem__(self, i):
        return self.graphs[i], self.labels[i]

    def __len__(self):
        return len(self.graphs)


dataset = SyntheticDataset()
graph, label = dataset[0]
print(graph, label)
Graph(num_nodes=15, num_edges=45,
      ndata_schemes={}
      edata_schemes={}) tensor(0)

通过 CSVDataset 从 CSV 创建数据集

前面的示例介绍了如何从 CSV 文件一步一步创建数据集。DGL 还提供了一个工具类 CSVDataset 用于从 CSV 文件读取和解析数据。更多详细信息请参阅 4.6 从 CSV 文件加载数据

# Thumbnail credits: (Un)common Use Cases for Graph Databases, Michal Bachman
# sphinx_gallery_thumbnail_path = '_static/blitz_6_load_data.png'

脚本总运行时间: (0 分钟 0.594 秒)

由 Sphinx-Gallery 生成的画廊