SBMMixtureDataset

class dgl.data.SBMMixtureDataset(n_graphs, n_nodes, n_communities, k=2, avg_deg=3, pq='Appendix_C', rng=None)[源码]

基类: DGLDataset

对称随机块模型混合

参考:Supervised Community Detection with Hierarchical Graph Neural Networks 的附录 C

参数:
  • n_graphs (int) – 图的数量。

  • n_nodes (int) – 节点数量。

  • n_communities (int) – 社区数量。

  • k (int, 可选的) – 乘数。默认为: 2

  • avg_deg (int, 可选的) – 平均度。默认为: 3

  • pq (list of pair of nonnegative float or str, 可选的) – 随机密度。此参数用于将来的扩展,目前始终使用默认值。默认为: Appendix_C

  • rng (numpy.random.RandomState, 可选的) – 随机数生成器。如果未给定,则为带 seed=Nonenumpy.random.RandomState(),它在可用时从 /dev/urandom(或 Windows 类似物)读取数据,否则从时钟读取种子。默认为: None

引发:

RuntimeError is raised if pq is not a list or string. – 如果 pq 不是列表或字符串,则引发 RuntimeError。

示例

>>> data = SBMMixtureDataset(n_graphs=16, n_nodes=10000, n_communities=2)
>>> from torch.utils.data import DataLoader
>>> dataloader = DataLoader(data, batch_size=1, collate_fn=data.collate_fn)
>>> for graph, line_graph, graph_degrees, line_graph_degrees, pm_pd in dataloader:
...     # your code here
__getitem__(idx)[源码]

按索引获取一个示例

参数:

idx (int) – 项目索引

返回:

  • graph (dgl.DGLGraph) – 原始图

  • line_graph (dgl.DGLGraph) – graph 的线图

  • graph_degree (numpy.ndarray) – graph 中每个节点的入度

  • line_graph_degree (numpy.ndarray) – line_graph 中每个节点的入度

  • pm_pd (numpy.ndarray) – 边指示矩阵 Pm 和 Pd

__len__()[源码]

数据集中的图数量。

collate_fn(x)[源码]

collate 函数,用于 dataloader

参数:

x (tuple) –

包含以下内容的批量数据

  • graph: dgl.DGLGraph

    原始图

  • line_graph: dgl.DGLGraph

    graph 的线图

  • graph_degree: numpy.ndarray

    graph 中每个节点的入度

  • line_graph_degree: numpy.ndarray

    line_graph 中每个节点的入度

  • pm_pd: numpy.ndarray

    边指示矩阵 Pm 和 Pd

返回:

  • g_batch (dgl.DGLGraph) – 批量图

  • lg_batch (dgl.DGLGraph) – 批量线图

  • degg_batch (numpy.ndarray) – g_batch 中每个节点的批量入度

  • deglg_batch (numpy.ndarray) – lg_batch 中每个节点的批量入度

  • pm_pd_batch (numpy.ndarray) – 边指示矩阵 Pm 和 Pd 的批量数据