采样子图
- class dgl.graphbolt.SampledSubgraph[源码]
基类:
object
采样子图的抽象类。在异构图的上下文中,每个字段应为 Dict 类型。否则,对于同构图,每个字段应对应其各自的值类型。
- exclude_edges(edges: Dict[str, Tensor] | Tensor, assume_num_node_within_int32: bool = True, async_op: bool = False)[源码]
从采样子图中排除边。
无论采样子图是否具有压缩的行/列节点,都可以使用此函数。如果原始子图具有压缩的行或列节点,则返回的子图中的相应行或列节点也将被压缩。
- 参数:
self (SampledSubgraph) – 采样子图。
edges (Union[torch.Tensor, Dict[str, torch.Tensor]]) – 要排除的边。如果采样子图是同构的,则 edges 应该是一个 N*2 的张量,表示要排除的边。如果采样子图是异构的,则 edges 应该是一个字典,键是边类型,值是相应的要排除的边。
assume_num_node_within_int32 (bool) – 如果为 True,则假定提供的 edges 中的节点 ID 值在 int32 范围内,这可以显著提高计算速度。默认值:True
async_op (bool) – 布尔值,指示调用是否是异步的。如果是,可以通过对返回的 future 调用 wait 来获取结果。
- 返回值:
继承自 SampledSubgraph 的类实例。
- 返回类型:
示例
>>> import dgl.graphbolt as gb >>> import torch >>> sampled_csc = {"A:relation:B": gb.CSCFormatBase( ... indptr=torch.tensor([0, 1, 2, 3]), ... indices=torch.tensor([0, 1, 2]))} >>> original_column_node_ids = {"B": torch.tensor([10, 11, 12])} >>> original_row_node_ids = {"A": torch.tensor([13, 14, 15])} >>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])} >>> subgraph = gb.SampledSubgraphImpl( ... sampled_csc=sampled_csc, ... original_column_node_ids=original_column_node_ids, ... original_row_node_ids=original_row_node_ids, ... original_edge_ids=original_edge_ids ... ) >>> edges_to_exclude = {"A:relation:B": torch.tensor([[14, 11], [15, 12]])} >>> result = subgraph.exclude_edges(edges_to_exclude) >>> print(result.sampled_csc) {'A:relation:B': CSCFormatBase(indptr=tensor([0, 1, 1, 1]), indices=tensor([0]), )} >>> print(result.original_column_node_ids) {'B': tensor([10, 11, 12])} >>> print(result.original_row_node_ids) {'A': tensor([13, 14, 15])} >>> print(result.original_edge_ids) {'A:relation:B': tensor([19])}
- to_pyg(x: Tensor | Dict[str, Tensor]) PyGLayerData | PyGLayerHeteroData [源码]
处理层输入,使其可以被 PyG 模型层消费。
- 参数:
x (Union[torch.Tensor, Dict[str, torch.Tensor]]) – GNN 层的输入节点特征。
- 返回值:
一个命名元组类,包含 x、edge_index 和 size 字段。通常,PyG GNN 层的 forward 方法会接受这些作为参数。
- 返回类型:
Union[PyGLayerData, PyGLayerHeteroData]
- property original_column_node_ids: Tensor | Dict[str, Tensor]
返回原始图中对应的逆向列节点 ID。列在原始图中的逆向节点 ID。图结构可以被视为由行和列组成的对,这是列的映射 ID。
如果 original_column_node_ids 是一个张量:它表示原始节点 ID。
如果 original_column_node_ids 是一个字典:键应为节点类型,值应为对应的原始异构节点 ID。
如果存在,则表示列 ID 已被压缩,并且 sampled_csc 中的列 ID 与这些压缩的 ID 匹配。
- property original_edge_ids: Tensor | Dict[str, Tensor]
返回原始图中对应的逆向边 ID。边在原始图中的逆向 ID。这在需要边特征时很有用。
如果 original_edge_ids 是一个张量:它表示原始边 ID。
如果 original_edge_ids 是一个字典:键应为边类型,值应为对应的原始异构边 ID。
- property original_row_node_ids: Tensor | Dict[str, Tensor]
返回原始图中对应的逆向行节点 ID。行在原始图中的逆向节点 ID。图结构可以被视为由行和列组成的对,这是行的映射 ID。
如果 original_row_node_ids 是一个张量:它表示原始节点 ID。
如果 original_row_node_ids 是一个字典:键应为节点类型,值应为对应的原始异构节点 ID。
如果存在,则表示行 ID 已被压缩,并且 sampled_csc 中的行 ID 与这些压缩的 ID 匹配。
- property sampled_csc: CSCFormatBase | Dict[str, CSCFormatBase]
- 返回以 CSC 格式表示边的节点对。
如果 sampled_csc 是 CSCFormatBase:它应采用 CSC 格式。indptr 存储数据数组中每列开始的索引。indices 存储非零元素的行索引。
如果 sampled_csc 是一个字典:键应为边类型,值应为对应的节点对。其中的 ID 是异构 ID。
示例
同构图。
>>> import dgl.graphbolt as gb >>> import torch >>> sampled_csc = gb.CSCFormatBase( ... indptr=torch.tensor([0, 1, 2, 3]), ... indices=torch.tensor([0, 1, 2])) >>> print(sampled_csc) CSCFormatBase(indptr=tensor([0, 1, 2, 3]), indices=tensor([0, 1, 2]), )
异构图。
>>> sampled_csc = {"A:relation:B": gb.CSCFormatBase( ... indptr=torch.tensor([0, 1, 2, 3]), ... indices=torch.tensor([0, 1, 2]))} >>> print(sampled_csc) {'A:relation:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]), indices=tensor([0, 1, 2]), )}