ItemSet

class dgl.graphbolt.ItemSet(items: int | Tensor | Tuple[Tensor], names: str | Tuple[str] | None = None)[source]

基类: object

张量或张量元组的包装器。

参数:
  • items (Union[int, torch.Tensor, Tuple[torch.Tensor]]) –

    要包装的张量。- 如果它是一个单个标量(一个整数或一个只包含单个值的张量),

    该项将被视为由 torch.arange 创建的 range_tensor。

    • 如果它是多维张量,索引将沿着第一个维度进行。

    • 如果它是一个元组,元组中的每个项都必须是一个张量。

  • names (Union[str, Tuple[str]], optional) – 项的名称。如果它是一个元组,每个名称必须与 items 参数中的项对应。命名是任意的,但通常做法是,名称应从 [‘labels’, ‘seeds’, ‘indexes’] 中选择,以与类 dgl.graphbolt.MiniBatch 的属性对齐。

示例

>>> import torch
>>> from dgl import graphbolt as gb
  1. 整数:节点数量。

>>> num = 10
>>> item_set = gb.ItemSet(num, names="seeds")
>>> list(item_set)
[tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), tensor(5),
 tensor(6), tensor(7), tensor(8), tensor(9)]
>>> item_set[:]
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> item_set.names
('seeds',)
  1. Torch 标量:节点数量。与整数相比,dtype 可自定义。

>>> num = torch.tensor(10, dtype=torch.int32)
>>> item_set = gb.ItemSet(num, names="seeds")
>>> list(item_set)
[tensor(0, dtype=torch.int32), tensor(1, dtype=torch.int32),
 tensor(2, dtype=torch.int32), tensor(3, dtype=torch.int32),
 tensor(4, dtype=torch.int32), tensor(5, dtype=torch.int32),
 tensor(6, dtype=torch.int32), tensor(7, dtype=torch.int32),
 tensor(8, dtype=torch.int32), tensor(9, dtype=torch.int32)]
>>> item_set[:]
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)
>>> item_set.names
('seeds',)
  1. 单个张量:种子节点。

>>> node_ids = torch.arange(0, 5)
>>> item_set = gb.ItemSet(node_ids, names="seeds")
>>> list(item_set)
[tensor(0), tensor(1), tensor(2), tensor(3), tensor(4)]
>>> item_set[:]
tensor([0, 1, 2, 3, 4])
>>> item_set.names
('seeds',)
  1. 形状相同的张量元组:种子节点和标签。

>>> node_ids = torch.arange(0, 5)
>>> labels = torch.arange(5, 10)
>>> item_set = gb.ItemSet(
...     (node_ids, labels), names=("seeds", "labels"))
>>> list(item_set)
[(tensor(0), tensor(5)), (tensor(1), tensor(6)), (tensor(2), tensor(7)),
 (tensor(3), tensor(8)), (tensor(4), tensor(9))]
>>> item_set[:]
(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]))
>>> item_set.names
('seeds', 'labels')
  1. 形状不同的张量元组:种子和标签。

>>> seeds = torch.arange(0, 10).reshape(-1, 2)
>>> labels = torch.tensor([1, 1, 0, 0, 0])
>>> item_set = gb.ItemSet(
...     (seeds, labels), names=("seeds", "lables"))
>>> list(item_set)
[(tensor([0, 1]), tensor([1])),
 (tensor([2, 3]), tensor([1])),
 (tensor([4, 5]), tensor([0])),
 (tensor([6, 7]), tensor([0])),
 (tensor([8, 9]), tensor([0]))]
>>> item_set[:]
(tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]),
 tensor([1, 1, 0, 0, 0]))
>>> item_set.names
('seeds', 'labels')
  1. 形状不同的张量元组:超链接和标签。

>>> seeds = torch.arange(0, 10).reshape(-1, 5)
>>> labels = torch.tensor([1, 0])
>>> item_set = gb.ItemSet(
...     (seeds, labels), names=("seeds", "lables"))
>>> list(item_set)
[(tensor([0, 1, 2, 3, 4]), tensor([1])),
 (tensor([5, 6, 7, 8, 9]), tensor([0]))]
>>> item_set[:]
(tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
 tensor([1, 0]))
>>> item_set.names
('seeds', 'labels')
property names: Tuple[str]

返回项的名称。

property num_items: int

返回项的数量。