SBMMixtureDataset
- class dgl.data.SBMMixtureDataset(n_graphs, n_nodes, n_communities, k=2, avg_deg=3, pq='Appendix_C', rng=None)[源码]
基类:
DGLDataset
对称随机块模型混合
参考:Supervised Community Detection with Hierarchical Graph Neural Networks 的附录 C
- 参数:
n_graphs (int) – 图的数量。
n_nodes (int) – 节点数量。
n_communities (int) – 社区数量。
k (int, 可选的) – 乘数。默认为: 2
avg_deg (int, 可选的) – 平均度。默认为: 3
pq (list of pair of nonnegative float or str, 可选的) – 随机密度。此参数用于将来的扩展,目前始终使用默认值。默认为: Appendix_C
rng (numpy.random.RandomState, 可选的) – 随机数生成器。如果未给定,则为带 seed=None 的 numpy.random.RandomState(),它在可用时从 /dev/urandom(或 Windows 类似物)读取数据,否则从时钟读取种子。默认为: None
- 引发:
RuntimeError is raised if pq is not a list or string. – 如果 pq 不是列表或字符串,则引发 RuntimeError。
示例
>>> data = SBMMixtureDataset(n_graphs=16, n_nodes=10000, n_communities=2) >>> from torch.utils.data import DataLoader >>> dataloader = DataLoader(data, batch_size=1, collate_fn=data.collate_fn) >>> for graph, line_graph, graph_degrees, line_graph_degrees, pm_pd in dataloader: ... # your code here
- __getitem__(idx)[源码]
按索引获取一个示例
- 参数:
idx (int) – 项目索引
- 返回:
graph (
dgl.DGLGraph
) – 原始图line_graph (
dgl.DGLGraph
) – graph 的线图graph_degree (numpy.ndarray) – graph 中每个节点的入度
line_graph_degree (numpy.ndarray) – line_graph 中每个节点的入度
pm_pd (numpy.ndarray) – 边指示矩阵 Pm 和 Pd
- collate_fn(x)[源码]
collate 函数,用于 dataloader
- 参数:
x (tuple) –
包含以下内容的批量数据
- graph:
dgl.DGLGraph
原始图
- graph:
- line_graph:
dgl.DGLGraph
graph 的线图
- line_graph:
- graph_degree: numpy.ndarray
graph 中每个节点的入度
- line_graph_degree: numpy.ndarray
line_graph 中每个节点的入度
- pm_pd: numpy.ndarray
边指示矩阵 Pm 和 Pd
- 返回:
g_batch (
dgl.DGLGraph
) – 批量图lg_batch (
dgl.DGLGraph
) – 批量线图degg_batch (numpy.ndarray) – g_batch 中每个节点的批量入度
deglg_batch (numpy.ndarray) – lg_batch 中每个节点的批量入度
pm_pd_batch (numpy.ndarray) – 边指示矩阵 Pm 和 Pd 的批量数据