TorchBasedFeature
- class dgl.graphbolt.TorchBasedFeature(torch_feature: Tensor, metadata: Dict | None = None)[source]
基类:
Feature
一个基于 PyTorch 特征的包装器。
通过一个 PyTorch 特征初始化一个基于 PyTorch 的特征存储。注意特征可以是在内存中或在磁盘上。
- 参数:
torch_feature (torch.Tensor) – PyTorch 特征。注意张量的维度应大于 1。
示例
>>> import torch >>> from dgl import graphbolt as gb
特征在内存中。
>>> torch_feat = torch.arange(10).reshape(2, -1) >>> feature = gb.TorchBasedFeature(torch_feat) >>> feature.read() tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) >>> feature.read(torch.tensor([0])) tensor([[0, 1, 2, 3, 4]]) >>> feature.update(torch.tensor([[1 for _ in range(5)]]), ... torch.tensor([1])) >>> feature.read(torch.tensor([0, 1])) tensor([[0, 1, 2, 3, 4], [1, 1, 1, 1, 1]]) >>> feature.size() torch.Size([5])
2. 特征在磁盘上。请注意,您可以使用 gb.numpy_save_aligned 代替 np.save,这可能会提高性能。
>>> import numpy as np >>> arr = np.array([[1, 2], [3, 4]]) >>> np.save("/tmp/arr.npy", arr) >>> torch_feat = torch.from_numpy(np.load("/tmp/arr.npy", mmap_mode="r+")) >>> feature = gb.TorchBasedFeature(torch_feat) >>> feature.read() tensor([[1, 2], [3, 4]]) >>> feature.read(torch.tensor([0])) tensor([[1, 2]])
锁页 CPU 特征。
>>> torch_feat = torch.arange(10).reshape(2, -1).pin_memory() >>> feature = gb.TorchBasedFeature(torch_feat) >>> feature.read().device device(type='cuda', index=0) >>> feature.read(torch.tensor([0]).cuda()).device device(type='cuda', index=0)
- read(ids: Tensor | None = None)[source]
按索引读取特征。
如果特征在锁页 CPU 内存上,并且 ids 在 GPU 或锁页 CPU 内存上,则将由 GPU 读取,返回的张量将在 GPU 上。否则,返回的张量将在 CPU 上。
- 参数:
ids (torch.Tensor, 可选) – 特征的索引。如果指定,则只读取指定索引的特征。如果为 None,则返回整个特征。
- 返回值:
读取的特征。
- 返回类型:
torch.Tensor
- read_async(ids: Tensor)[source]
按索引异步读取特征。
- 参数:
ids (torch.Tensor) – 特征的索引。只读取指定索引的特征。
- 返回值:
返回的生成器对象在第
read_async_num_stages(ids.device)
次调用时返回一个 future。可以通过调用返回的 future 对象的.wait()
方法来访问返回结果。多次调用.wait()
是未定义行为。- 返回类型:
一个生成器对象。
示例
>>> import dgl.graphbolt as gb >>> feature = gb.Feature(...) >>> ids = torch.tensor([0, 2]) >>> for stage, future in enumerate(feature.read_async(ids)): ... pass >>> assert stage + 1 == feature.read_async_num_stages(ids.device) >>> result = future.wait() # result contains the read values.