dgl.graphbolt.compact_csc_format

dgl.graphbolt.compact_csc_format(csc_formats: CSCFormatBase | Dict[str, CSCFormatBase], dst_nodes: Tensor | Dict[str, Tensor], dst_timestamps: Tensor | Dict[str, Tensor] | None = None)[源码]

将 CSC 格式中的行(源)ID 重新标记为从 0 开始的连续范围,并返回每种类型的原始行节点 ID。

注意:1. 列(目标)ID 包含在重新标记的行 ID 中。2. 如果存在重复的行 ID,它们将不会被去重,而是被视为不同的节点。3. 如果提供了 dst_timestamps,每个目标节点的时间戳将被广播到其相应的源节点。

参数:
  • csc_formats (Union[CSCFormatBase, Dict[str, CSCFormatBase]]) – 表示源-目标边的 CSC 格式。 - 如果 csc_formats 是一个 CSCFormatBase:表示图是同构的。此外,其中的 indptr 和 indice 应该是表示 CSC 格式中源和目标对的 torch.tensor。其中的 ID 是同构 ID。 - 如果 csc_formats 是一个 Dict[str, CSCFormatBase]:键应为边类型,值应为 CSC 格式的节点对。其中的 ID 是异构 ID。

  • dst_nodes (Union[torch.Tensor, Dict[str, torch.Tensor]]) – 节点对中所有目标节点的节点。 - 如果 dst_nodes 是一个张量:表示图是同构的。 - 如果 dst_nodes 是一个字典:键是节点类型,值是相应的节点。其中的 ID 是异构 ID。

  • dst_timestamps (Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]]) – CSC 格式中所有目标节点的时间戳。如果提供,每个目标节点的时间戳将被广播到其相应的源节点。

返回值:

输入中所有节点的原始行节点 ID(按类型划分)的张量。压缩后的 CSC 格式,其中节点 ID 被替换为从 0 到 N 的映射节点 ID。如果提供了 dst_timestamps,则返回输入中所有节点的源时间戳(按类型划分)。

返回类型:

Tuple[original_row_node_ids, compacted_csc_formats, ...]

示例

>>> import dgl.graphbolt as gb
>>> csc_formats = {
...     "n2:e2:n1": gb.CSCFormatBase(
...         indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([5, 4, 6])
...     ),
...     "n1:e1:n1": gb.CSCFormatBase(
...         indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([1, 2, 3])
...     ),
... }
>>> dst_nodes = {"n1": torch.LongTensor([2, 4])}
>>> original_row_node_ids, compacted_csc_formats = gb.compact_csc_format(
...     csc_formats, dst_nodes
... )
>>> original_row_node_ids
{'n1': tensor([2, 4, 1, 2, 3]), 'n2': tensor([5, 4, 6])}
>>> compacted_csc_formats
{'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 1, 3]),
            indices=tensor([0, 1, 2]),
), 'n1:e1:n1': CSCFormatBase(indptr=tensor([0, 1, 3]),
            indices=tensor([2, 3, 4]),
)}
>>> csc_formats = {
...     "n2:e2:n1": gb.CSCFormatBase(
...         indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([5, 4, 6])
...     ),
...     "n1:e1:n1": gb.CSCFormatBase(
...         indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([1, 2, 3])
...     ),
... }
>>> dst_nodes = {"n1": torch.LongTensor([2, 4])}
>>> original_row_node_ids, compacted_csc_formats = gb.compact_csc_format(
...     csc_formats, dst_nodes
... )
>>> original_row_node_ids
{'n1': tensor([2, 4, 1, 2, 3]), 'n2': tensor([5, 4, 6])}
>>> compacted_csc_formats
{'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 1, 3]),
            indices=tensor([0, 1, 2]),
), 'n1:e1:n1': CSCFormatBase(indptr=tensor([0, 1, 3]),
            indices=tensor([2, 3, 4]),
)}
>>> dst_timestamps = {"n1": torch.LongTensor([10, 20])}
>>> (
...     original_row_node_ids,
...     compacted_csc_formats,
...     src_timestamps,
... ) = gb.compact_csc_format(csc_formats, dst_nodes, dst_timestamps)
>>> src_timestamps
{'n1': tensor([10, 20, 10, 20, 20]), 'n2': tensor([10, 20, 20])}