dgl.DGLGraph.adj_tensors
- DGLGraph.adj_tensors(fmt, etype=None)[source]
将给定边类型的邻接矩阵以稀疏矩阵表示的张量形式返回。默认情况下,返回的邻接矩阵的行代表边的源节点,列代表目标节点。 :param fmt: 必须是
coo
,csr
或csc
之一。 :type fmt: str :param etype: 边的类型名称。允许的类型名称格式为源节点类型、边类型和目标节点类型的
(str, str, str)
。或者如果名称可以唯一标识图中的三元组格式,则为一个
str
边类型名称。
如果图只有一种边类型,则可以省略此参数。
- 返回值:
如果
fmt
是coo
,则返回一对源节点和目标节点的 ID 张量。如果fmt
是csr
或csc
,则返回邻接矩阵的 CSR 或 CSC 表示,形式为一个包含三个张量的元组(indptr, indices, edge_ids)
。其中edge_ids
可能是一个包含 0 个元素的空张量,在这种情况下,边 ID 是从 0 开始的连续整数。- 返回类型:
tuple[Tensor]
示例
>>> g = dgl.graph(([0, 1, 2], [1, 2, 3])) >>> g.adj_tensors('coo') (tensor([0, 1, 2]), tensor([1, 2, 3])) >>> g.adj_tensors('csr') (tensor([0, 1, 2, 3, 3]), tensor([1, 2, 3]), tensor([0, 1, 2]))