SubgraphSampler

class dgl.graphbolt.SubgraphSampler(datapipe, *args, **kwargs)[source]

基类: MiniBatchTransformer

一个子图采样器,用于从大图中的给定节点集采样子图。

函数名: sample_subgraph

此类是所有子图采样器的基类。SubgraphSampler 的任何子类都应实现 sample_subgraphs() 方法或 sampling_stages() 方法,以定义细粒度的采样阶段,从而利用 GraphBolt DataLoader 提供的优化。

参数:
  • datapipe (DataPipe) – 数据管道。

  • args (非关键字参数) – 将传递给 sampling_stages 的参数。

  • kwargs (关键字参数) – 将传递给 sampling_stages 的参数。预处理阶段在将 asynchronouscooperative 参数传递给采样阶段之前会使用它们。

sample_subgraphs(seeds, seeds_timestamp, seeds_pre_time_window=None)[source]

从给定种子采样子图,可能带有时间约束。

SubgraphSampler 的任何子类都应实现此方法。

参数:
  • seeds (Union[torch.Tensor, Dict[str, torch.Tensor]]) – 种子节点。

  • seeds_timestamp (Union[torch.Tensor, Dict[str, torch.Tensor]]) – 种子节点的时间戳。如果给定,采样子图不应包含任何比种子节点时间戳新的节点或边。默认值: None。

  • seeds_pre_time_window (Union[torch.Tensor, Dict[str, torch.Tensor]]) – 节点的时间窗口表示 seeds_timestamp 之前的一段时间。如果提供,只有时间戳落在 [seeds_timestamp - seeds_pre_time_window, seeds_timestamp] 范围内的邻居和相关边会被过滤。

返回值:

  • Union[torch.Tensor, Dict[str, torch.Tensor]] – 输入节点。

  • List[SampledSubgraph] – 采样的子图。

示例

>>> @functional_datapipe("my_sample_subgraph")
>>> class MySubgraphSampler(SubgraphSampler):
>>>     def __init__(self, datapipe, graph, fanouts):
>>>         super().__init__(datapipe)
>>>         self.graph = graph
>>>         self.fanouts = fanouts
>>>     def sample_subgraphs(self, seeds):
>>>         # Sample subgraphs from the given seeds.
>>>         subgraphs = []
>>>         subgraphs_nodes = []
>>>         for fanout in reversed(self.fanouts):
>>>             subgraph = self.graph.sample_neighbors(seeds, fanout)
>>>             subgraphs.insert(0, subgraph)
>>>             subgraphs_nodes.append(subgraph.nodes)
>>>             seeds = subgraph.nodes
>>>         subgraphs_nodes = torch.unique(torch.cat(subgraphs_nodes))
>>>         return subgraphs_nodes, subgraphs
sampling_stages(datapipe)[source]

采样阶段在此通过链接到数据管道定义。默认实现期望 sample_subgraphs() 被实现。要定义细粒度阶段,应重写此方法。