dgl.DGLGraph.adj_tensors

DGLGraph.adj_tensors(fmt, etype=None)[source]

将给定边类型的邻接矩阵以稀疏矩阵表示的张量形式返回。默认情况下,返回的邻接矩阵的行代表边的源节点,列代表目标节点。 :param fmt: 必须是 coo, csrcsc 之一。 :type fmt: str :param etype: 边的类型名称。允许的类型名称格式为

  • 源节点类型、边类型和目标节点类型的 (str, str, str)

  • 或者如果名称可以唯一标识图中的三元组格式,则为一个 str 边类型名称。

如果图只有一种边类型,则可以省略此参数。

返回值:

如果 fmtcoo,则返回一对源节点和目标节点的 ID 张量。如果 fmtcsrcsc,则返回邻接矩阵的 CSR 或 CSC 表示,形式为一个包含三个张量的元组 (indptr, indices, edge_ids)。其中 edge_ids 可能是一个包含 0 个元素的空张量,在这种情况下,边 ID 是从 0 开始的连续整数。

返回类型:

tuple[Tensor]

示例

>>> g = dgl.graph(([0, 1, 2], [1, 2, 3]))
>>> g.adj_tensors('coo')
(tensor([0, 1, 2]), tensor([1, 2, 3]))
>>> g.adj_tensors('csr')
(tensor([0, 1, 2, 3, 3]), tensor([1, 2, 3]), tensor([0, 1, 2]))