ItemSampler

class dgl.graphbolt.ItemSampler(item_set: ~dgl.graphbolt.itemset.ItemSet | ~dgl.graphbolt.itemset.HeteroItemSet, batch_size: int, minibatcher: ~typing.Callable | None = <function minibatcher_default>, drop_last: bool | None = False, shuffle: bool | None = False, seed: int | None = None)

Bases: IterDataPipe

A sampler that iterates over input items and creates minibatches.

Input items can be node IDs, node pairs with or without labels, or node pairs with negative sources/destinations.

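Note: the ItemSampler class is intentionally not decorated with torch.utils.data.functional_datapipe, so it cannot be invoked in functional form (that is, as a method chained onto another datapipe). Any iterable datapipe from torch.utils.data.datapipes can, however, be appended after it.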
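
To make the note concrete, here is a toy, dependency-free sketch of the pattern it describes: a source stage that is constructed directly (like ItemSampler) and a downstream stage appended through a `.map` method. `ToyBatcher` and `ToyMapper` are hypothetical classes written for illustration only; they are not part of GraphBolt or torch.utils.data.

```python
from typing import Callable, Iterable, Iterator, List


class ToyBatcher:
    """Hypothetical source stage: built directly, yields fixed-size batches."""

    def __init__(self, items: Iterable[int], batch_size: int) -> None:
        self.items = list(items)
        self.batch_size = batch_size

    def __iter__(self) -> Iterator[List[int]]:
        for start in range(0, len(self.items), self.batch_size):
            yield self.items[start:start + self.batch_size]

    def map(self, fn: Callable) -> "ToyMapper":
        # Appending a downstream stage wraps the current stage.
        return ToyMapper(self, fn)


class ToyMapper:
    """Hypothetical downstream stage that applies fn to every batch."""

    def __init__(self, source: Iterable, fn: Callable) -> None:
        self.source = source
        self.fn = fn

    def __iter__(self) -> Iterator:
        for batch in self.source:
            yield self.fn(batch)


pipe = ToyBatcher(range(10), 4).map(lambda batch: [x + 1 for x in batch])
result = list(pipe)
print(result)  # [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10]]
```

The output mirrors the Mapper example further down, where `data_pipe.map(add_one)` is appended to an ItemSampler.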

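Parameters:
  • item_set (Union[ItemSet, HeteroItemSet]) – The data to be sampled.

  • batch_size (int) – The size of each batch.

  • minibatcher (Optional[Callable]) – A callable that takes a list of items and returns a MiniBatch.

  • drop_last (bool) – Whether to drop the last batch if it is not full.

  • shuffle (bool) – Whether to shuffle the items before sampling.

  • seed (int) – The seed for reproducible random shuffling. If None, a random seed is generated.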
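
As a rough mental model of how batch_size, drop_last, shuffle, and seed interact, the following plain-Python sketch batches a list of integers. `sketch_batches` is a hypothetical function, not the ItemSampler implementation: it skips MiniBatch construction and does not model per-epoch behavior or distributed partitioning.

```python
import random
from typing import List, Optional, Sequence


def sketch_batches(
    items: Sequence[int],
    batch_size: int,
    drop_last: bool = False,
    shuffle: bool = False,
    seed: Optional[int] = None,
) -> List[List[int]]:
    """Hypothetical illustration of batch_size / drop_last / shuffle / seed."""
    order = list(range(len(items)))
    if shuffle:
        # random.Random(None) seeds itself from OS randomness (or the clock),
        # so seed=None differs between calls; a fixed seed repeats the order.
        rng = random.Random(seed)
        rng.shuffle(order)
    batches: List[List[int]] = []
    for start in range(0, len(order), batch_size):
        batch = [items[i] for i in order[start:start + batch_size]]
        if drop_last and len(batch) < batch_size:
            break  # Only the final batch can be short; discard it.
        batches.append(batch)
    return batches


print(sketch_batches(list(range(10)), 4))                  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(sketch_batches(list(range(10)), 4, drop_last=True))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

Calling it twice with `shuffle=True, seed=7` returns the same batches both times, which is the reproducibility the seed parameter is meant to provide.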

Examples

  1. Node IDs.

>>> import torch
>>> from dgl import graphbolt as gb
>>> item_set = gb.ItemSet(torch.arange(0, 10), names="seeds")
>>> item_sampler = gb.ItemSampler(
...     item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seeds=tensor([0, 1, 2, 3]), sampled_subgraphs=None,
    node_features=None, labels=None, input_nodes=None,
    indexes=None, edge_features=None, compacted_seeds=None,
    blocks=None,)
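
With 10 node IDs and batch_size=4, iterating the sampler above to the end yields three batches, the last holding only [8, 9]; with drop_last=True that short batch is discarded. The arithmetic, as a small sketch:

```python
num_items = 10
batch_size = 4

# drop_last=False keeps the trailing partial batch: ceil(10 / 4) batches.
batches_keep = (num_items + batch_size - 1) // batch_size
# drop_last=True keeps only full batches: floor(10 / 4) batches.
batches_drop = num_items // batch_size
# Size of the final batch when it is kept.
last_size = num_items % batch_size or batch_size
print(batches_keep, batches_drop, last_size)  # 3 2 2
```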

  2. Node pairs.

>>> item_set = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2),
...     names="seeds")
>>> item_sampler = gb.ItemSampler(
...     item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seeds=tensor([[0, 1], [2, 3], [4, 5], [6, 7]]),
    sampled_subgraphs=None, node_features=None, labels=None,
    input_nodes=None, indexes=None, edge_features=None,
    compacted_seeds=None, blocks=None,)

  3. Node pairs and labels.

>>> item_set = gb.ItemSet(
...     (torch.arange(0, 20).reshape(-1, 2), torch.arange(10, 20)),
...     names=("seeds", "labels")
... )
>>> item_sampler = gb.ItemSampler(
...     item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seeds=tensor([[0, 1], [2, 3], [4, 5], [6, 7]]),
    sampled_subgraphs=None, node_features=None,
    labels=tensor([10, 11, 12, 13]), input_nodes=None,
    indexes=None, edge_features=None, compacted_seeds=None,
    blocks=None,)
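
The names tuple determines which MiniBatch field each aligned tensor lands in. The sketch below imitates that mapping with plain lists: each item is a (seed_pair, label) tuple, and a batch is transposed into a dict keyed by name. `sketch_minibatcher` is hypothetical; unlike the documented minibatcher, it also takes the names as a second argument (an assumption for illustration) and returns a dict instead of a real MiniBatch.

```python
from typing import Dict, List, Sequence, Tuple


def sketch_minibatcher(rows: List[Tuple], names: Sequence[str]) -> Dict[str, list]:
    """Hypothetical minibatcher: per-item tuples -> dict of named columns."""
    columns = zip(*rows)  # Transpose: one column per field.
    return {name: list(col) for name, col in zip(names, columns)}


seeds = [[2 * i, 2 * i + 1] for i in range(10)]  # Same pairs as torch.arange(0, 20).reshape(-1, 2).
labels = list(range(10, 20))
items = list(zip(seeds, labels))  # Each item is (seed_pair, label).
first_batch = sketch_minibatcher(items[:4], ("seeds", "labels"))
print(first_batch)
# {'seeds': [[0, 1], [2, 3], [4, 5], [6, 7]], 'labels': [10, 11, 12, 13]}
```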

  4. Node pairs, labels, and indexes.

>>> seeds = torch.arange(0, 20).reshape(-1, 2)
>>> labels = torch.tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0])
>>> indexes = torch.tensor([0, 1, 0, 0, 0, 0, 1, 1, 1, 1])
>>> item_set = gb.ItemSet((seeds, labels, indexes), names=("seeds",
...     "labels", "indexes"))
>>> item_sampler = gb.ItemSampler(
...     item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seeds=tensor([[0, 1], [2, 3], [4, 5], [6, 7]]),
    sampled_subgraphs=None, node_features=None,
    labels=tensor([1, 1, 0, 0]), input_nodes=None,
    indexes=tensor([0, 1, 0, 0]), edge_features=None,
    compacted_seeds=None, blocks=None,)
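
Because seeds, labels, and indexes are aligned row by row, every batch slices all three in lockstep, which is why the first batch above carries labels [1, 1, 0, 0] and indexes [0, 1, 0, 0]. The sketch below groups the same rows by their indexes value. Reading indexes as "which query a row belongs to" (as in link-prediction evaluation, with label 1 marking the positive pair) is an assumption made for illustration, not something stated on this page.

```python
from collections import defaultdict

seeds = [[2 * i, 2 * i + 1] for i in range(10)]
labels = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
indexes = [0, 1, 0, 0, 0, 0, 1, 1, 1, 1]

# Group (pair, label) rows by their indexes value.
groups = defaultdict(list)
for pair, label, index in zip(seeds, labels, indexes):
    groups[index].append((pair, label))

for index, rows in sorted(groups.items()):
    print(index, rows)
```

Under that reading, index 0 collects the positive pair [0, 1] plus four label-0 pairs, and index 1 collects [2, 3] plus four more.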

  5. Further process batches with other datapipes (such as torch.utils.data.datapipes.iter.Mapper).

>>> item_set = gb.ItemSet(torch.arange(0, 10))
>>> data_pipe = gb.ItemSampler(item_set, 4)
>>> def add_one(batch):
...     return batch + 1
>>> data_pipe = data_pipe.map(add_one)
>>> list(data_pipe)
[tensor([1, 2, 3, 4]), tensor([5, 6, 7, 8]), tensor([ 9, 10])]

  6. Heterogeneous node IDs.

>>> ids = {
...     "user": gb.ItemSet(torch.arange(0, 5), names="seeds"),
...     "item": gb.ItemSet(torch.arange(0, 6), names="seeds"),
... }
>>> item_set = gb.HeteroItemSet(ids)
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seeds={'user': tensor([0, 1, 2, 3])}, sampled_subgraphs=None,
    node_features=None, labels=None, input_nodes=None, indexes=None,
    edge_features=None, compacted_seeds=None, blocks=None,)
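
The example shows only the first batch. A plausible model of what follows, assumed here rather than taken from this page, is that the items of all types are laid out back to back in dict order and each batch is then grouped by type, so a batch may span more than one node type. `sketch_hetero_batches` is a hypothetical helper that follows that assumption.

```python
from typing import Dict, Iterator, List


def sketch_hetero_batches(ids: Dict[str, List[int]], batch_size: int) -> Iterator[Dict[str, List[int]]]:
    # ASSUMPTION: items of all types are concatenated in dict order.
    flat = [(ntype, i) for ntype, values in ids.items() for i in values]
    for start in range(0, len(flat), batch_size):
        batch: Dict[str, List[int]] = {}
        for ntype, i in flat[start:start + batch_size]:
            batch.setdefault(ntype, []).append(i)  # Group the batch by type.
        yield batch


ids = {"user": list(range(5)), "item": list(range(6))}
batches = list(sketch_hetero_batches(ids, 4))
print(batches[0])  # {'user': [0, 1, 2, 3]}
```

Here the second batch would be {'user': [4], 'item': [0, 1, 2]}; check the real sampler's output before relying on that layout.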

  7. Heterogeneous node pairs.

>>> seeds_like = torch.arange(0, 10).reshape(-1, 2)
>>> seeds_follow = torch.arange(10, 20).reshape(-1, 2)
>>> item_set = gb.HeteroItemSet({
...     "user:like:item": gb.ItemSet(
...         seeds_like, names="seeds"),
...     "user:follow:user": gb.ItemSet(
...         seeds_follow, names="seeds"),
... })
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seeds={'user:like:item':
    tensor([[0, 1], [2, 3], [4, 5], [6, 7]])}, sampled_subgraphs=None,
    node_features=None, labels=None, input_nodes=None, indexes=None,
    edge_features=None, compacted_seeds=None, blocks=None,)
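
The keys in this example are edge-type strings of the form source type, relation, destination type, joined by colons. A minimal sketch of splitting such a key, using a hypothetical helper `split_etype` (GraphBolt has its own utilities for this; this is not one of them):

```python
from typing import Tuple


def split_etype(etype: str) -> Tuple[str, str, str]:
    """Hypothetical helper: split 'src:relation:dst' into its three parts.

    Raises ValueError if the string does not have exactly three parts.
    """
    src, relation, dst = etype.split(":")
    return src, relation, dst


print(split_etype("user:like:item"))    # ('user', 'like', 'item')
print(split_etype("user:follow:user"))  # ('user', 'follow', 'user')
```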

  8. Heterogeneous node pairs and labels.

>>> seeds_like = torch.arange(0, 10).reshape(-1, 2)
>>> labels_like = torch.arange(0, 5)
>>> seeds_follow = torch.arange(10, 20).reshape(-1, 2)
>>> labels_follow = torch.arange(5, 10)
>>> item_set = gb.HeteroItemSet({
...     "user:like:item": gb.ItemSet((seeds_like, labels_like),
...         names=("seeds", "labels")),
...     "user:follow:user": gb.ItemSet((seeds_follow, labels_follow),
...         names=("seeds", "labels")),
... })
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seeds={'user:like:item':
    tensor([[0, 1], [2, 3], [4, 5], [6, 7]])}, sampled_subgraphs=None,
    node_features=None, labels={'user:like:item': tensor([0, 1, 2, 3])},
    input_nodes=None, indexes=None, edge_features=None,
    compacted_seeds=None, blocks=None,)

  9. Heterogeneous node pairs, labels, and indexes.

>>> seeds_like = torch.arange(0, 10).reshape(-1, 2)
>>> labels_like = torch.tensor([1, 1, 0, 0, 0])
>>> indexes_like = torch.tensor([0, 1, 0, 0, 1])
>>> seeds_follow = torch.arange(20, 30).reshape(-1, 2)
>>> labels_follow = torch.tensor([1, 1, 0, 0, 0])
>>> indexes_follow = torch.tensor([0, 1, 0, 0, 1])
>>> item_set = gb.HeteroItemSet({
...     "user:like:item": gb.ItemSet((seeds_like, labels_like,
...         indexes_like), names=("seeds", "labels", "indexes")),
...     "user:follow:user": gb.ItemSet((seeds_follow, labels_follow,
...         indexes_follow), names=("seeds", "labels", "indexes")),
... })
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seeds={'user:like:item':
    tensor([[0, 1], [2, 3], [4, 5], [6, 7]])}, sampled_subgraphs=None,
    node_features=None, labels={'user:like:item': tensor([1, 1, 0, 0])},
    input_nodes=None, indexes={'user:like:item': tensor([0, 1, 0, 0])},
    edge_features=None, compacted_seeds=None, blocks=None,)