注意
跳到末尾下载完整示例代码。
构建自己的数据集
本教程假设您已经了解使用 GNN 进行节点分类训练的基础知识以及如何创建、加载和存储 DGL 图。
通过本教程的学习,您将能够
创建用于节点分类、链接预测或图分类的自定义图数据集。
(预计时间:15 分钟)
DGLDataset
对象概述
您的自定义图数据集应该继承 dgl.data.DGLDataset
类并实现以下方法
__getitem__(self, i)
:检索数据集的第i
个示例。一个示例通常包含一个 DGL 图,偶尔也包含其标签。__len__(self)
:数据集中示例的数量。process(self)
:从磁盘加载和处理原始数据。
从 CSV 创建用于节点分类或链接预测的数据集
节点分类数据集通常包含一个单独的图以及其节点和边特征。
本教程使用一个基于 Zachary 的空手道俱乐部网络的小数据集。它包含
一个
members.csv
文件,其中包含所有成员的属性。一个
interactions.csv
文件,其中包含两名俱乐部成员之间的两两交互。
import urllib.request
import pandas as pd
urllib.request.urlretrieve(
"https://data.dgl.ai/tutorial/dataset/members.csv", "./members.csv"
)
urllib.request.urlretrieve(
"https://data.dgl.ai/tutorial/dataset/interactions.csv",
"./interactions.csv",
)
members = pd.read_csv("./members.csv")
members.head()
interactions = pd.read_csv("./interactions.csv")
interactions.head()
本教程将成员视为节点,交互视为边。它将年龄作为节点的数值特征,所属俱乐部作为节点的标签,边权重作为边的数值特征。
注意
原始的 Zachary 空手道俱乐部网络没有成员年龄。本教程中的年龄是合成生成的,用于演示如何在创建数据集时将节点特征添加到图中。
注意
在实践中,直接将年龄作为数值特征在机器学习中可能效果不佳;分箱或标准化特征等策略会更好。本教程为简单起见直接使用原始值。
import os
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import torch
from dgl.data import DGLDataset
class KarateClubDataset(DGLDataset):
def __init__(self):
super().__init__(name="karate_club")
def process(self):
nodes_data = pd.read_csv("./members.csv")
edges_data = pd.read_csv("./interactions.csv")
node_features = torch.from_numpy(nodes_data["Age"].to_numpy())
node_labels = torch.from_numpy(
nodes_data["Club"].astype("category").cat.codes.to_numpy()
)
edge_features = torch.from_numpy(edges_data["Weight"].to_numpy())
edges_src = torch.from_numpy(edges_data["Src"].to_numpy())
edges_dst = torch.from_numpy(edges_data["Dst"].to_numpy())
self.graph = dgl.graph(
(edges_src, edges_dst), num_nodes=nodes_data.shape[0]
)
self.graph.ndata["feat"] = node_features
self.graph.ndata["label"] = node_labels
self.graph.edata["weight"] = edge_features
# If your dataset is a node classification dataset, you will need to assign
# masks indicating whether a node belongs to training, validation, and test set.
n_nodes = nodes_data.shape[0]
n_train = int(n_nodes * 0.6)
n_val = int(n_nodes * 0.2)
train_mask = torch.zeros(n_nodes, dtype=torch.bool)
val_mask = torch.zeros(n_nodes, dtype=torch.bool)
test_mask = torch.zeros(n_nodes, dtype=torch.bool)
train_mask[:n_train] = True
val_mask[n_train : n_train + n_val] = True
test_mask[n_train + n_val :] = True
self.graph.ndata["train_mask"] = train_mask
self.graph.ndata["val_mask"] = val_mask
self.graph.ndata["test_mask"] = test_mask
def __getitem__(self, i):
return self.graph
def __len__(self):
return 1
dataset = KarateClubDataset()
graph = dataset[0]
print(graph)
/dgl/tutorials/blitz/6_load_data.py:106: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
node_labels = torch.from_numpy(
Graph(num_nodes=34, num_edges=156,
ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64), 'label': Scheme(shape=(), dtype=torch.int8), 'train_mask': Scheme(shape=(), dtype=torch.bool), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool)}
edata_schemes={'weight': Scheme(shape=(), dtype=torch.float64)})
由于链接预测数据集只涉及一个图,因此准备链接预测数据集的体验与准备节点分类数据集相同。
从 CSV 创建用于图分类的数据集
创建图分类数据集涉及实现 __getitem__
以返回图及其图级标签。
本教程演示了如何使用以下合成 CSV 数据创建图分类数据集
graph_edges.csv
:包含三列graph_id
:图的 ID。src
:给定图的一条边的源节点。dst
:给定图的一条边的目标节点。
graph_properties.csv
:包含三列graph_id
:图的 ID。label
:图的标签。num_nodes
:图中的节点数量。
urllib.request.urlretrieve(
"https://data.dgl.ai/tutorial/dataset/graph_edges.csv", "./graph_edges.csv"
)
urllib.request.urlretrieve(
"https://data.dgl.ai/tutorial/dataset/graph_properties.csv",
"./graph_properties.csv",
)
edges = pd.read_csv("./graph_edges.csv")
properties = pd.read_csv("./graph_properties.csv")
edges.head()
properties.head()
class SyntheticDataset(DGLDataset):
def __init__(self):
super().__init__(name="synthetic")
def process(self):
edges = pd.read_csv("./graph_edges.csv")
properties = pd.read_csv("./graph_properties.csv")
self.graphs = []
self.labels = []
# Create a graph for each graph ID from the edges table.
# First process the properties table into two dictionaries with graph IDs as keys.
# The label and number of nodes are values.
label_dict = {}
num_nodes_dict = {}
for _, row in properties.iterrows():
label_dict[row["graph_id"]] = row["label"]
num_nodes_dict[row["graph_id"]] = row["num_nodes"]
# For the edges, first group the table by graph IDs.
edges_group = edges.groupby("graph_id")
# For each graph ID...
for graph_id in edges_group.groups:
# Find the edges as well as the number of nodes and its label.
edges_of_id = edges_group.get_group(graph_id)
src = edges_of_id["src"].to_numpy()
dst = edges_of_id["dst"].to_numpy()
num_nodes = num_nodes_dict[graph_id]
label = label_dict[graph_id]
# Create a graph and add it to the list of graphs and labels.
g = dgl.graph((src, dst), num_nodes=num_nodes)
self.graphs.append(g)
self.labels.append(label)
# Convert the label list to tensor for saving.
self.labels = torch.LongTensor(self.labels)
def __getitem__(self, i):
return self.graphs[i], self.labels[i]
def __len__(self):
return len(self.graphs)
dataset = SyntheticDataset()
graph, label = dataset[0]
print(graph, label)
Graph(num_nodes=15, num_edges=45,
ndata_schemes={}
edata_schemes={}) tensor(0)
通过 CSVDataset
从 CSV 创建数据集
前面的示例介绍了如何从 CSV 文件一步一步创建数据集。DGL 还提供了一个工具类 CSVDataset
用于从 CSV 文件读取和解析数据。更多详细信息请参阅 4.6 从 CSV 文件加载数据。
# Thumbnail credits: (Un)common Use Cases for Graph Databases, Michal Bachman
# sphinx_gallery_thumbnail_path = '_static/blitz_6_load_data.png'
脚本总运行时间: (0 分钟 0.594 秒)