dgl.graphbolt.compact_csc_format
- dgl.graphbolt.compact_csc_format(csc_formats: CSCFormatBase | Dict[str, CSCFormatBase], dst_nodes: Tensor | Dict[str, Tensor], dst_timestamps: Tensor | Dict[str, Tensor] | None = None)[源码]
将 CSC 格式中的行(源)ID 重新标记为从 0 开始的连续范围,并返回每种类型的原始行节点 ID。
注意:1. 列(目标)ID 包含在重新标记的行 ID 中。2. 如果存在重复的行 ID,它们将不会被去重,而是被视为不同的节点。3. 如果提供了 dst_timestamps,每个目标节点的时间戳将被广播到其相应的源节点。
- 参数:
csc_formats (Union[CSCFormatBase, Dict[str, CSCFormatBase]]) – 表示源-目标边的 CSC 格式。 - 如果 csc_formats 是一个 CSCFormatBase:表示图是同构的。此外,其中的 indptr 和 indice 应该是表示 CSC 格式中源和目标对的 torch.tensor。其中的 ID 是同构 ID。 - 如果 csc_formats 是一个 Dict[str, CSCFormatBase]:键应为边类型,值应为 CSC 格式的节点对。其中的 ID 是异构 ID。
dst_nodes (Union[torch.Tensor, Dict[str, torch.Tensor]]) – 节点对中所有目标节点的节点。 - 如果 dst_nodes 是一个张量:表示图是同构的。 - 如果 dst_nodes 是一个字典:键是节点类型,值是相应的节点。其中的 ID 是异构 ID。
dst_timestamps (Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]]) – CSC 格式中所有目标节点的时间戳。如果提供,每个目标节点的时间戳将被广播到其相应的源节点。
- 返回值:
输入中所有节点的原始行节点 ID(按类型划分)的张量。压缩后的 CSC 格式,其中节点 ID 被替换为从 0 到 N 的映射节点 ID。如果提供了 dst_timestamps,则返回输入中所有节点的源时间戳(按类型划分)。
- 返回类型:
Tuple[original_row_node_ids, compacted_csc_formats, ...]
示例
>>> import dgl.graphbolt as gb >>> csc_formats = { ... "n2:e2:n1": gb.CSCFormatBase( ... indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([5, 4, 6]) ... ), ... "n1:e1:n1": gb.CSCFormatBase( ... indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([1, 2, 3]) ... ), ... } >>> dst_nodes = {"n1": torch.LongTensor([2, 4])} >>> original_row_node_ids, compacted_csc_formats = gb.compact_csc_format( ... csc_formats, dst_nodes ... ) >>> original_row_node_ids {'n1': tensor([2, 4, 1, 2, 3]), 'n2': tensor([5, 4, 6])} >>> compacted_csc_formats {'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 1, 3]), indices=tensor([0, 1, 2]), ), 'n1:e1:n1': CSCFormatBase(indptr=tensor([0, 1, 3]), indices=tensor([2, 3, 4]), )}
>>> csc_formats = { ... "n2:e2:n1": gb.CSCFormatBase( ... indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([5, 4, 6]) ... ), ... "n1:e1:n1": gb.CSCFormatBase( ... indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([1, 2, 3]) ... ), ... } >>> dst_nodes = {"n1": torch.LongTensor([2, 4])} >>> original_row_node_ids, compacted_csc_formats = gb.compact_csc_format( ... csc_formats, dst_nodes ... ) >>> original_row_node_ids {'n1': tensor([2, 4, 1, 2, 3]), 'n2': tensor([5, 4, 6])} >>> compacted_csc_formats {'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 1, 3]), indices=tensor([0, 1, 2]), ), 'n1:e1:n1': CSCFormatBase(indptr=tensor([0, 1, 3]), indices=tensor([2, 3, 4]), )}
>>> dst_timestamps = {"n1": torch.LongTensor([10, 20])} >>> ( ... original_row_node_ids, ... compacted_csc_formats, ... src_timestamps, ... ) = gb.compact_csc_format(csc_formats, dst_nodes, dst_timestamps) >>> src_timestamps {'n1': tensor([10, 20, 10, 20, 20]), 'n2': tensor([10, 20, 20])}