KNNGraph
- class dgl.nn.pytorch.factory.KNNGraph(k)[source]
基类:
Module
将一个点集转换为图,或将一批点数相同的点集转换为这些图的批处理联合的层。
KNNGraph 的实现步骤如下
计算所有点的 NxN 成对距离矩阵。
对于每个点,选择距离最小的 k 个点作为其 k 近邻。
构建一个图,其中的边从每个节点的 k 近邻指向该节点。
整体计算复杂度为 \(O(N^2(logN + D)\)。
如果提供了批量点集,点集 \(i\) 中的点 \(j\) 将映射到图节点 ID:\(i \times M + j\),其中 \(M\) 是每个点集中的节点数。
每个节点的前驱是对应点的 k 近邻。
- 参数:
k (int) – 邻居数量。
注意事项
为节点找到的近邻包含节点本身。
示例
以下示例使用 PyTorch 后端。
>>> import torch >>> from dgl.nn.pytorch.factory import KNNGraph >>> >>> kg = KNNGraph(2) >>> x = torch.tensor([[0,1], [1,2], [1,3], [100, 101], [101, 102], [50, 50]]) >>> g = kg(x) >>> print(g.edges()) (tensor([0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 5]), tensor([0, 0, 1, 2, 1, 2, 5, 3, 4, 3, 4, 5]))
- forward(x, algorithm='bruteforce-blas', dist='euclidean', exclude_self=False)[source]
前向计算。
- 参数:
x (Tensor) – \((M, D)\) 或 \((N, M, D)\),其中 \(N\) 表示点集数量,\(M\) 表示每个点集中的点数,\(D\) 表示特征维度大小。
algorithm (str, optional) –
用于计算 k 近邻的算法。
’bruteforce-blas’ 将首先使用后端框架提供的 BLAS 矩阵乘法运算计算距离矩阵。然后使用 topk 算法获取 k 近邻。当点集较小时,此方法速度很快,但内存复杂度为 \(O(N^2)\),其中 \(N\) 是点的数量。
’bruteforce’ 将逐对计算距离,并在计算距离时直接选择 k 近邻。此方法比 ‘bruteforce-blas’ 慢,但内存开销较小(即 \(O(Nk)\),其中 \(N\) 是点数,\(k\) 是每个节点的近邻数),因为我们无需存储所有距离。
’bruteforce-sharemem’ (仅限 CUDA) 类似于 ‘bruteforce’,但在 CUDA 设备中使用共享内存作为缓冲区。当输入点的维度不大时,此方法比 ‘bruteforce’ 快。此方法仅在 CUDA 设备上可用。
’kd-tree’ 将使用 kd-tree 算法 (仅限 CPU)。此方法适用于低维数据(例如 3D 点云)
’nn-descent’ 是论文 Efficient k-nearest neighbor graph construction for generic similarity measures 中提出的一种近似方法。此方法将在“邻居的邻居”中搜索近邻候选。
(默认: ‘bruteforce-blas’)
dist (str, optional) –
用于计算点之间距离的度量。可以是以下度量之一:* ‘euclidean’:使用欧几里得距离 (L2 范数)
\(\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}\).
’cosine’:使用余弦距离。
(默认: ‘euclidean’)
exclude_self (bool, optional) – 如果为 True,则输出图将不包含自循环边,并且每个节点不计为其自身的 k 个邻居之一。如果为 False,则输出图将包含自循环边,并且节点将计为其自身的 k 个邻居之一。
- 返回值:
一个不带特征的 DGLGraph。
- 返回类型: