ItemSet
- class dgl.graphbolt.ItemSet(items: int | Tensor | Tuple[Tensor], names: str | Tuple[str] | None = None)[source]
基类:
object
张量或张量元组的包装器。
- 参数:
items (Union[int, torch.Tensor, Tuple[torch.Tensor]]) –
要包装的张量。- 如果它是一个单个标量(一个整数或一个只包含单个值的张量),
该项将被视为由 torch.arange 创建的 range_tensor。
如果它是多维张量,索引将沿着第一个维度进行。
如果它是一个元组,元组中的每个项都必须是一个张量。
names (Union[str, Tuple[str]], optional) – 项的名称。如果它是一个元组,每个名称必须与 items 参数中的项对应。命名是任意的,但通常做法是,名称应从 [‘labels’, ‘seeds’, ‘indexes’] 中选择,以与类 dgl.graphbolt.MiniBatch 的属性对齐。
示例
>>> import torch >>> from dgl import graphbolt as gb
整数:节点数量。
>>> num = 10 >>> item_set = gb.ItemSet(num, names="seeds") >>> list(item_set) [tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), tensor(5), tensor(6), tensor(7), tensor(8), tensor(9)] >>> item_set[:] tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) >>> item_set.names ('seeds',)
Torch 标量:节点数量。与整数相比,dtype 可自定义。
>>> num = torch.tensor(10, dtype=torch.int32) >>> item_set = gb.ItemSet(num, names="seeds") >>> list(item_set) [tensor(0, dtype=torch.int32), tensor(1, dtype=torch.int32), tensor(2, dtype=torch.int32), tensor(3, dtype=torch.int32), tensor(4, dtype=torch.int32), tensor(5, dtype=torch.int32), tensor(6, dtype=torch.int32), tensor(7, dtype=torch.int32), tensor(8, dtype=torch.int32), tensor(9, dtype=torch.int32)] >>> item_set[:] tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32) >>> item_set.names ('seeds',)
单个张量:种子节点。
>>> node_ids = torch.arange(0, 5) >>> item_set = gb.ItemSet(node_ids, names="seeds") >>> list(item_set) [tensor(0), tensor(1), tensor(2), tensor(3), tensor(4)] >>> item_set[:] tensor([0, 1, 2, 3, 4]) >>> item_set.names ('seeds',)
形状相同的张量元组:种子节点和标签。
>>> node_ids = torch.arange(0, 5) >>> labels = torch.arange(5, 10) >>> item_set = gb.ItemSet( ... (node_ids, labels), names=("seeds", "labels")) >>> list(item_set) [(tensor(0), tensor(5)), (tensor(1), tensor(6)), (tensor(2), tensor(7)), (tensor(3), tensor(8)), (tensor(4), tensor(9))] >>> item_set[:] (tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])) >>> item_set.names ('seeds', 'labels')
形状不同的张量元组:种子和标签。
>>> seeds = torch.arange(0, 10).reshape(-1, 2) >>> labels = torch.tensor([1, 1, 0, 0, 0]) >>> item_set = gb.ItemSet( ... (seeds, labels), names=("seeds", "lables")) >>> list(item_set) [(tensor([0, 1]), tensor([1])), (tensor([2, 3]), tensor([1])), (tensor([4, 5]), tensor([0])), (tensor([6, 7]), tensor([0])), (tensor([8, 9]), tensor([0]))] >>> item_set[:] (tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]), tensor([1, 1, 0, 0, 0])) >>> item_set.names ('seeds', 'labels')
形状不同的张量元组:超链接和标签。
>>> seeds = torch.arange(0, 10).reshape(-1, 5) >>> labels = torch.tensor([1, 0]) >>> item_set = gb.ItemSet( ... (seeds, labels), names=("seeds", "lables")) >>> list(item_set) [(tensor([0, 1, 2, 3, 4]), tensor([1])), (tensor([5, 6, 7, 8, 9]), tensor([0]))] >>> item_set[:] (tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]), tensor([1, 0])) >>> item_set.names ('seeds', 'labels')