TorchBasedFeatureStore

class dgl.graphbolt.TorchBasedFeatureStore(feat_data: List[OnDiskFeatureData])[源码]

基类: BasicFeatureStore

一个用于管理多个基于 PyTorch 的特征以进行访问的存储。

特征存储由 feat_data 描述。feat_data 是一个 OnDiskFeatureData 对象的列表。

对于一个特征存储,其格式必须是 PyTorch 或 Numpy 格式的“pt”或“npy”。如果格式是“pt”,则特征存储必须加载到内存中。如果格式是“npy”,则特征存储可以加载到内存或磁盘上。请注意,可以使用 gb.numpy_save_aligned 作为 np.save 的替代,以可能提高性能。

参数:

feat_data (List[OnDiskFeatureData]) – 特征存储的描述。

示例

>>> import torch
>>> import numpy as np
>>> from dgl import graphbolt as gb
>>> edge_label = torch.tensor([[1], [2], [3]])
>>> node_feat = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> torch.save(edge_label, "/tmp/edge_label.pt")
>>> gb.numpy_save_aligned("/tmp/node_feat.npy", node_feat.numpy())
>>> feat_data = [
...     gb.OnDiskFeatureData(domain="edge", type="author:writes:paper",
...         name="label", format="torch", path="/tmp/edge_label.pt",
...         in_memory=True),
...     gb.OnDiskFeatureData(domain="node", type="paper", name="feat",
...         format="numpy", path="/tmp/node_feat.npy", in_memory=False),
... ]
>>> feature_store = gb.TorchBasedFeatureStore(feat_data)
is_pinned()[源码]

如果所有存储的特征都已锁定(pinned),则返回 True。

pin_memory_()[源码]

原地操作,将特征存储复制到锁定内存。返回修改后的同一对象。

to(device)[源码]

TorchBasedFeatureStore 复制到指定设备。