SubgraphSampler
- class dgl.graphbolt.SubgraphSampler(datapipe, *args, **kwargs)[source]
-
一个子图采样器,用于从大图中的给定节点集采样子图。
函数名:
sample_subgraph
。此类是所有子图采样器的基类。SubgraphSampler 的任何子类都应实现
sample_subgraphs()
方法或sampling_stages()
方法,以定义细粒度的采样阶段,从而利用 GraphBolt DataLoader 提供的优化。- 参数:
datapipe (DataPipe) – 数据管道。
args (非关键字参数) – 将传递给 sampling_stages 的参数。
kwargs (关键字参数) – 将传递给 sampling_stages 的参数。预处理阶段在将 asynchronous 和 cooperative 参数传递给采样阶段之前会使用它们。
- sample_subgraphs(seeds, seeds_timestamp, seeds_pre_time_window=None)[source]
从给定种子采样子图,可能带有时间约束。
SubgraphSampler 的任何子类都应实现此方法。
- 参数:
seeds (Union[torch.Tensor, Dict[str, torch.Tensor]]) – 种子节点。
seeds_timestamp (Union[torch.Tensor, Dict[str, torch.Tensor]]) – 种子节点的时间戳。如果给定,采样子图不应包含任何比种子节点时间戳新的节点或边。默认值: None。
seeds_pre_time_window (Union[torch.Tensor, Dict[str, torch.Tensor]]) – 节点的时间窗口表示 seeds_timestamp 之前的一段时间。如果提供,只有时间戳落在 [seeds_timestamp - seeds_pre_time_window, seeds_timestamp] 范围内的邻居和相关边会被过滤。
- 返回值:
Union[torch.Tensor, Dict[str, torch.Tensor]] – 输入节点。
List[SampledSubgraph] – 采样的子图。
示例
>>> @functional_datapipe("my_sample_subgraph") >>> class MySubgraphSampler(SubgraphSampler): >>> def __init__(self, datapipe, graph, fanouts): >>> super().__init__(datapipe) >>> self.graph = graph >>> self.fanouts = fanouts >>> def sample_subgraphs(self, seeds): >>> # Sample subgraphs from the given seeds. >>> subgraphs = [] >>> subgraphs_nodes = [] >>> for fanout in reversed(self.fanouts): >>> subgraph = self.graph.sample_neighbors(seeds, fanout) >>> subgraphs.insert(0, subgraph) >>> subgraphs_nodes.append(subgraph.nodes) >>> seeds = subgraph.nodes >>> subgraphs_nodes = torch.unique(torch.cat(subgraphs_nodes)) >>> return subgraphs_nodes, subgraphs
- sampling_stages(datapipe)[source]
采样阶段在此通过链接到数据管道定义。默认实现期望
sample_subgraphs()
被实现。要定义细粒度阶段,应重写此方法。