KNNGraph

class dgl.nn.pytorch.factory.KNNGraph(k)[source]

基类: Module

将一个点集转换为图,或将一批点数相同的点集转换为这些图的批处理联合的层。

KNNGraph 的实现步骤如下

  1. 计算所有点的 NxN 成对距离矩阵。

  2. 对于每个点,选择距离最小的 k 个点作为其 k 近邻。

  3. 构建一个图,其中的边从每个节点的 k 近邻指向该节点。

整体计算复杂度为 \(O(N^2(logN + D)\)

如果提供了批量点集,点集 \(i\) 中的点 \(j\) 将映射到图节点 ID:\(i \times M + j\),其中 \(M\) 是每个点集中的节点数。

每个节点的前驱是对应点的 k 近邻。

参数:

k (int) – 邻居数量。

注意事项

为节点找到的近邻包含节点本身。

示例

以下示例使用 PyTorch 后端。

>>> import torch
>>> from dgl.nn.pytorch.factory import KNNGraph
>>>
>>> kg = KNNGraph(2)
>>> x = torch.tensor([[0,1],
                      [1,2],
                      [1,3],
                      [100, 101],
                      [101, 102],
                      [50, 50]])
>>> g = kg(x)
>>> print(g.edges())
    (tensor([0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 5]),
     tensor([0, 0, 1, 2, 1, 2, 5, 3, 4, 3, 4, 5]))
forward(x, algorithm='bruteforce-blas', dist='euclidean', exclude_self=False)[source]

前向计算。

参数:
  • x (Tensor) – \((M, D)\)\((N, M, D)\),其中 \(N\) 表示点集数量,\(M\) 表示每个点集中的点数,\(D\) 表示特征维度大小。

  • algorithm (str, optional) –

    用于计算 k 近邻的算法。

    • ’bruteforce-blas’ 将首先使用后端框架提供的 BLAS 矩阵乘法运算计算距离矩阵。然后使用 topk 算法获取 k 近邻。当点集较小时,此方法速度很快,但内存复杂度为 \(O(N^2)\),其中 \(N\) 是点的数量。

    • ’bruteforce’ 将逐对计算距离,并在计算距离时直接选择 k 近邻。此方法比 ‘bruteforce-blas’ 慢,但内存开销较小(即 \(O(Nk)\),其中 \(N\) 是点数,\(k\) 是每个节点的近邻数),因为我们无需存储所有距离。

    • ’bruteforce-sharemem’ (仅限 CUDA) 类似于 ‘bruteforce’,但在 CUDA 设备中使用共享内存作为缓冲区。当输入点的维度不大时,此方法比 ‘bruteforce’ 快。此方法仅在 CUDA 设备上可用。

    • ’kd-tree’ 将使用 kd-tree 算法 (仅限 CPU)。此方法适用于低维数据(例如 3D 点云)

    • ’nn-descent’ 是论文 Efficient k-nearest neighbor graph construction for generic similarity measures 中提出的一种近似方法。此方法将在“邻居的邻居”中搜索近邻候选。

    (默认: ‘bruteforce-blas’)

  • dist (str, optional) –

    用于计算点之间距离的度量。可以是以下度量之一:* ‘euclidean’:使用欧几里得距离 (L2 范数)

    \(\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}\).

    • ’cosine’:使用余弦距离。

    (默认: ‘euclidean’)

  • exclude_self (bool, optional) – 如果为 True,则输出图将不包含自循环边,并且每个节点不计为其自身的 k 个邻居之一。如果为 False,则输出图将包含自循环边,并且节点将计为其自身的 k 个邻居之一。

返回值:

一个不带特征的 DGLGraph。

返回类型:

DGLGraph