MiniGCDataset

class dgl.data.MiniGCDataset(num_graphs, min_num_v, max_num_v, seed=0, save_graph=True, force_reload=False, verbose=False, transform=None)[source]

基类:DGLDataset

合成图分类数据集类。

该数据集包含 8 种不同类型的图。

  • 类别 0:环图

  • 类别 1:星图

  • 类别 2:轮图

  • 类别 3:棒棒糖图

  • 类别 4:超立方体图

  • 类别 5:网格图

  • 类别 6:完全图

  • 类别 7:循环阶梯图

参数:
  • num_graphs (int) – 此数据集中的图数量。

  • min_num_v (int) – 图的最小节点数

  • max_num_v (int) – 图的最大节点数

  • seed (int, 默认为 0) – 数据生成的随机种子

  • transform (callable, 可选) – 一个转换函数,接受 DGLGraph 对象并返回转换后的版本。每次访问时,DGLGraph 对象都将被转换。

num_graphs

图的数量

类型:

int

min_num_v

最小节点数

类型:

int

max_num_v

最大节点数

类型:

int

num_classes

类别数量

类型:

int

示例

>>> data = MiniGCDataset(100, 16, 32, seed=0)

数据集实例是可迭代的

>>> len(data)
100
>>> g, label = data[64]
>>> g
Graph(num_nodes=20, num_edges=82,
      ndata_schemes={}
      edata_schemes={})
>>> label
tensor(5)

批量处理图和标签以进行小批量训练

>>> graphs, labels = zip(*[data[i] for i in range(16)])
>>> batched_graphs = dgl.batch(graphs)
>>> batched_labels = torch.tensor(labels)
>>> batched_graphs
Graph(num_nodes=356, num_edges=1060,
      ndata_schemes={}
      edata_schemes={})
__getitem__(idx)[source]

获取第 idx 个样本。

参数:

idx (int) – 样本索引。

返回:

图及其标签。

返回类型:

(dgl.Graph, Tensor)

__len__()[source]

返回数据集中的图数量。