线图神经网络

作者: Qi Huang, Yu Gai, Minjie Wang, Zheng Zhang

警告

本教程旨在通过代码解释,深入了解这篇论文。因此,此处实现并未针对运行效率进行优化。对于推荐的实现,请参阅官方示例

在本教程中,您将学习如何通过实现线图神经网络 (LGNN) 来解决社区检测任务。社区检测,或称图聚类,旨在将图中的顶点划分到簇中,使得簇内的节点彼此之间更相似。

图卷积网络教程中,您学习了如何在半监督设置下对输入图的节点进行分类。您使用了图卷积神经网络 (GCN) 作为图特征的嵌入机制。

为了将图神经网络 (GNN) 推广到监督社区检测,论文Supervised Community Detection with Line Graph Neural Networks中介绍了一种基于线图的 GNN 变体。该模型的一个亮点是增强了直接的 GNN 架构,使其能够在由非回溯算子定义的边邻接线图上运行。

线图神经网络 (LGNN) 展示了 DGL 如何通过结合基本的张量操作、稀疏矩阵乘法和消息传递 API 来实现高级图算法。

在以下各节中,您将学习社区检测、线图、LGNN 及其实现。

使用 Cora 数据集进行监督社区检测任务

社区检测

在社区检测任务中,您将相似节点聚类而不是为其打标签。节点相似性通常被描述为每个簇内部具有更高的密度。

社区检测和节点分类有什么区别? 与节点分类相比,社区检测侧重于检索图中的簇信息,而不是为节点分配特定标签。例如,只要一个节点与其社区成员聚在一起,该节点被分配为“社区 A”还是“社区 B”并不重要;而在电影网络分类任务中,将所有“好电影”都分配给标签“坏电影”将是灾难性的。

那么,社区检测算法和 k-means 等其他聚类算法之间有什么区别呢? 社区检测算法在图结构数据上运行。与 k-means 相比,社区检测利用图结构,而不是简单地根据节点特征进行聚类。

Cora 数据集

为了与 GCN 教程保持一致,您使用 Cora 数据集来演示一个简单的社区检测任务。Cora 是一个科学出版物数据集,包含 2708 篇论文,分属于七个不同的机器学习领域。在这里,您将 Cora 建模为一个有向图,其中每个节点是一篇论文,每条边是一个引用链接(A->B 表示 A 引用 B)。这是整个 Cora 数据集的可视化。

cora

Cora 自然包含七个类别,下面的统计数据表明每个类别都满足我们的社区假设,即同一类别的节点之间的连接概率高于与不同类别节点的连接概率。以下代码片段验证了类内边多于类间边。

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl.data import citation_graph as citegrh

data = citegrh.load_cora()

G = data[0]
labels = th.tensor(G.ndata["label"])

# find all the nodes labeled with class 0
label0_nodes = th.nonzero(labels == 0, as_tuple=False).squeeze()
# find all the edges pointing to class 0 nodes
src, _ = G.in_edges(label0_nodes)
src_labels = labels[src]
# find all the edges whose both endpoints are in class 0
intra_src = th.nonzero(src_labels == 0, as_tuple=False)
print("Intra-class edges percent: %.4f" % (len(intra_src) / len(src_labels)))

import matplotlib.pyplot as plt
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
/dgl/tutorials/models/1_gnn/6_line_graph.py:102: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  labels = th.tensor(G.ndata["label"])
Intra-class edges percent: 0.6994

使用测试数据集从 Cora 中提取二分类社区子图

不失一般性,在本教程中,我们将任务范围限制在二分类社区检测。

注意

为了从 Cora 创建一个用于练习的二分类社区数据集,首先从原始 Cora 的七个类别中提取所有两类对。对于每一对,您将每个类别视为一个社区,并找到至少包含一条跨社区边的最大子图作为训练示例。因此,这个小数据集中共有 21 个训练样本。

使用以下代码,您可以可视化一个训练样本及其社区结构。

import networkx as nx

train_set = dgl.data.CoraBinary()
G1, pmpd1, label1 = train_set[1]
nx_G1 = G1.to_networkx()


def visualize(labels, g):
    pos = nx.spring_layout(g, seed=1)
    plt.figure(figsize=(8, 8))
    plt.axis("off")
    nx.draw_networkx(
        g,
        pos=pos,
        node_size=50,
        cmap=plt.get_cmap("coolwarm"),
        node_color=labels,
        edge_color="k",
        arrows=False,
        width=0.5,
        style="dotted",
        with_labels=False,
    )


visualize(label1, nx_G1)
6 line graph
Downloading /root/.dgl/cora_binary.zip from https://data.dgl.ai/dataset/cora_binary.zip...

/root/.dgl/cora_binary.zip:   0%|          | 0.00/373k [00:00<?, ?B/s]
/root/.dgl/cora_binary.zip: 100%|██████████| 373k/373k [00:00<00:00, 16.3MB/s]
Extracting file to /root/.dgl/cora_binary_2ffdf50c
Done saving data into cached files.
Done saving data into cached files.

要了解更多信息,请查阅原始研究论文,了解如何推广到多社区情况。

监督设置下的社区检测

社区检测问题可以通过监督和无监督方法解决。您可以在监督设置下将社区检测形式化如下

  • 每个训练示例由 \((G, L)\) 组成,其中 \(G\) 是有向图 \((V, E)\)。对于 \(V\) 中的每个节点 \(v\),我们分配一个真实社区标签 \(z_v \in \{0,1\}\)

  • 参数化模型 \(f(G, \theta)\) 预测节点 \(V\) 的标签集 \(\tilde{Z} = f(G)\)

  • 对于每个示例 \((G,L)\),模型学习最小化专门设计的损失函数(等变损失)\(L_{equivariant} = (\tilde{Z},Z)\)

注意

在这个监督设置中,模型自然会预测每个社区的标签。然而,社区分配应该对标签排列具有等变性。为了实现这一点,在每个前向传播过程中,我们取所有可能的标签排列计算出的损失中的最小值。

数学上,这意味着 \(L_{equivariant} = \underset{\pi \in S_c} {min}-\log(\hat{\pi}, \pi)\),其中 \(S_c\) 是所有标签排列的集合,\(\hat{\pi}\) 是预测标签的集合,\(- \log(\hat{\pi},\pi)\) 表示负对数似然。

例如,对于一个节点集为 \(\{1,2,3,4\}\) 且社区分配为 \(\{A, A, A, B\}\) 的样本图,每个节点的标签 \(l \in \{0,1\}\),所有可能的排列组 \(S_c = \{\{0,0,0,1\}, \{1,1,1,0\}\}\)

线图神经网络关键思想

这个主题的一个关键创新是使用线图。与之前教程中的模型不同,消息传递不仅发生在原始图上(例如 Cora 的二分类社区子图),还发生在与原始图关联的线图上。

什么是线图?

在图论中,线图是一种图表示,它编码了原始图中的边邻接结构。

具体来说,线图 \(L(G)\) 将原始图 G 的一条边变成一个节点。下图(取自研究论文)对此进行了说明。

lg

这里,\(e_{A}:= (i\rightarrow j)\)\(e_{B}:= (j\rightarrow k)\) 是原始图 \(G\) 中的两条边。在线图 \(G_L\) 中,它们对应于节点 \(v^{l}_{A}, v^{l}_{B}\)

接下来很自然的问题是,如何连接线图中的节点? 如何连接两条边? 在这里,我们使用以下连接规则

lg 中的两个节点 \(v^{l}_{A}\), \(v^{l}_{B}\) 相连,如果对应的 g 中的两条边 \(e_{A}, e_{B}\) 共享一个且仅一个节点:\(e_{A}\) 的目标节点是 \(e_{B}\) 的源节点 (\(j\))。

注意

数学上,这个定义对应于一个称为非回溯算子的概念:\(B_{(i \rightarrow j), (\hat{i} \rightarrow \hat{j})}\) \(= \begin{cases} 1 \text{ if } j = \hat{i}, \hat{j} \neq i\\ 0 \text{ otherwise} \end{cases}\),其中当 \(B_{node1, node2} = 1\) 时形成一条边。

LGNN 中的一层,算法结构

LGNN 将一系列线图神经网络层串联起来。图表示 \(x\) 及其线图伴侣 \(y\) 随着数据流演化如下。

alg

在第 \(k\) 层,第 \(l\) 个通道的第 \(i\) 个神经元使用以下方式更新其嵌入 \(x^{(k+1)}_{i,l}\)

\[\begin{split}\begin{split} x^{(k+1)}_{i,l} ={}&\rho[x^{(k)}_{i}\theta^{(k)}_{1,l} +(Dx^{(k)})_{i}\theta^{(k)}_{2,l} \\ &+\sum^{J-1}_{j=0}(A^{2^{j}}x^{k})_{i}\theta^{(k)}_{3+j,l}\\ &+[\{\text{Pm},\text{Pd}\}y^{(k)}]_{i}\theta^{(k)}_{3+J,l}] \\ &+\text{跳跃连接} \qquad i \in V, l = 1,2,3, ... b_{k+1}/2 \end{split}\end{split}\]

然后,线图表示 \(y^{(k+1)}_{i',l^{'}}\) 使用,

\[\begin{split}\begin{split} y^{(k+1)}_{i',l^{'}} = {}&\rho[y^{(k)}_{i^{'}}\gamma^{(k)}_{1,l^{'}}+ (D_{L(G)}y^{(k)})_{i^{'}}\gamma^{(k)}_{2,l^{'}}\\ &+\sum^{J-1}_{j=0}(A_{L(G)}^{2^{j}}y^{k})_{i}\gamma^{(k)}_{3+j,l^{'}}\\ &+[\{\text{Pm},\text{Pd}\}^{T}x^{(k+1)}]_{i^{'}}\gamma^{(k)}_{3+J,l^{'}}]\\ &+\text{跳跃连接} \qquad i^{'} \in V_{l}, l^{'} = 1,2,3, ... b^{'}_{k+1}/2 \end{split}\end{split}\]

其中 \(\text{skip-connection}\) 指的是在没有非线性 \(\rho\) 的情况下执行相同的操作,并使用线性投影 \(\theta_\{\frac{b_{k+1}}{2} + 1, ..., b_{k+1}-1, b_{k+1}\}\)\(\gamma_\{\frac{b_{k+1}}{2} + 1, ..., b_{k+1}-1, b_{k+1}\}\)

在 DGL 中实现 LGNN

尽管上一节中的方程式可能看起来令人望而生畏,但在实现 LGNN 之前理解以下信息会有所帮助。

这两个方程式是对称的,可以作为同一个类的两个实例实现,但参数不同。第一个方程作用于图表示 \(x\),而第二个方程作用于线图表示 \(y\)。我们将这种抽象记为 \(f\)。那么第一个是 \(f(x,y; \theta_x)\),第二个是 \(f(y,x, \theta_y)\)。也就是说,它们被参数化以分别计算原始图及其伴随线图的表示。

每个方程式由四项组成。以第一个为例,如下所示。

  • \(x^{(k)}\theta^{(k)}_{1,l}\),前一层输出 \(x^{(k)}\) 的线性投影,记为 \(\text{prev}(x)\)

  • \((Dx^{(k)})\theta^{(k)}_{2,l}\),度算子作用于 \(x^{(k)}\) 后的线性投影,记为 \(\text{deg}(x)\)

  • \(\sum^{J-1}_{j=0}(A^{2^{j}}x^{(k)})\theta^{(k)}_{3+j,l}\)\(2^{j}\) 邻接算子作用于 \(x^{(k)}\) 后的求和,记为 \(\text{radius}(x)\)

  • \([\{Pm,Pd\}y^{(k)}]\theta^{(k)}_{3+J,l}\),使用关联矩阵 \(\{Pm, Pd\}\) 融合另一个图的嵌入信息,然后进行线性投影,记为 \(\text{fuse}(y)\)

每个项都使用不同的参数再次执行,并且在求和之后没有非线性。因此,\(f\) 可以写成

\[\begin{split}\begin{split} f(x^{(k)},y^{(k)}) = {}\rho[&\text{prev}(x^{(k-1)}) + \text{deg}(x^{(k-1)}) +\text{radius}(x^{k-1}) +\text{fuse}(y^{(k)})]\\ +&\text{prev}(x^{(k-1)}) + \text{deg}(x^{(k-1)}) +\text{radius}(x^{k-1}) +\text{fuse}(y^{(k)}) \end{split}\end{split}\]

两个方程式按以下顺序串联

\[\begin{split}\begin{split} x^{(k+1)} = {}& f(x^{(k)}, y^{(k)})\\ y^{(k+1)} = {}& f(y^{(k)}, x^{(k+1)}) \end{split}\end{split}\]

请记住本概述中列出的观察结果,并继续进行实现。重要的一点是,您对提到的项使用不同的策略。

注意

通过这个解释,您可以更全面地理解 \(\{Pm, Pd\}\)。粗略地说,glg(线图)如何与循环简略传播协同工作之间存在关联。在这里,您将 \(\{Pm, Pd\}\) 在数据集中实现为 SciPy COO 稀疏矩阵,并在批量处理时将它们堆叠为张量。另一种批量处理解决方案是将 $\{Pm, Pd\}$ 视为二分图的邻接矩阵,它将线图的特征映射到图的特征,反之亦然。

\(\text{prev}\)\(\text{deg}\) 实现为张量操作

线性投影和度操作都只是矩阵乘法。将它们写成 PyTorch 张量操作。

__init__ 中,您定义投影变量。

self.linear_prev = nn.Linear(in_feats, out_feats)
self.linear_deg = nn.Linear(in_feats, out_feats)

forward() 中,\(\text{prev}\)\(\text{deg}\) 与任何其他 PyTorch 张量操作相同。

prev_proj = self.linear_prev(feat_a)
deg_proj = self.linear_deg(deg * feat_a)

在 DGL 中将 \(\text{radius}\) 实现为消息传递

如 GCN 教程中所讨论的,您可以将一个邻接算子形式化为进行一步消息传递。作为一个泛化,\(2^j\) 邻接操作可以形式化为执行 \(2^j\) 步消息传递。因此,求和等价于对每个节点的 \(2^j, j=0, 1, 2..\) 步消息传递表示进行求和,即收集每个节点 \(2^{j}\) 邻域中的信息。

__init__ 中,定义在消息传递的每个 \(2^j\) 步中使用的投影变量。

self.linear_radius = nn.ModuleList(
        [nn.Linear(in_feats, out_feats) for i in range(radius)])

__forward__ 中,使用函数 aggregate_radius() 从多跳收集数据。这可以在以下代码中看到。请注意 update_all 被多次调用。

# Return a list containing features gathered from multiple radius.
import dgl.function as fn


def aggregate_radius(radius, g, z):
    # initializing list to collect message passing result
    z_list = []
    g.ndata["z"] = z
    # pulling message from 1-hop neighbourhood
    g.update_all(fn.copy_u(u="z", out="m"), fn.sum(msg="m", out="z"))
    z_list.append(g.ndata["z"])
    for i in range(radius - 1):
        for j in range(2**i):
            # pulling message from 2^j neighborhood
            g.update_all(fn.copy_u(u="z", out="m"), fn.sum(msg="m", out="z"))
        z_list.append(g.ndata["z"])
    return z_list

\(\text{fuse}\) 实现为稀疏矩阵乘法

\(\{Pm, Pd\}\) 是一个稀疏矩阵,在每一列上只有两个非零项。因此,您可以在数据集中将其构造为稀疏矩阵,并将 \(\text{fuse}\) 实现为稀疏矩阵乘法。

__forward__

fuse = self.linear_fuse(th.mm(pm_pd, feat_b))

完成 \(f(x, y)\)

最后,以下展示了如何将所有项相加,将其传递给跳跃连接和批量归一化。

result = prev_proj + deg_proj + radius_proj + fuse

将结果传递给跳跃连接。

result = th.cat([result[:, :n], F.relu(result[:, n:])], 1)

然后将结果传递给批量归一化。

result = self.bn(result) #Batch Normalization.

这是一个 LGNN 层抽象 \(f(x,y)\) 的完整代码

class LGNNCore(nn.Module):
    def __init__(self, in_feats, out_feats, radius):
        super(LGNNCore, self).__init__()
        self.out_feats = out_feats
        self.radius = radius

        self.linear_prev = nn.Linear(in_feats, out_feats)
        self.linear_deg = nn.Linear(in_feats, out_feats)
        self.linear_radius = nn.ModuleList(
            [nn.Linear(in_feats, out_feats) for i in range(radius)]
        )
        self.linear_fuse = nn.Linear(in_feats, out_feats)
        self.bn = nn.BatchNorm1d(out_feats)

    def forward(self, g, feat_a, feat_b, deg, pm_pd):
        # term "prev"
        prev_proj = self.linear_prev(feat_a)
        # term "deg"
        deg_proj = self.linear_deg(deg * feat_a)

        # term "radius"
        # aggregate 2^j-hop features
        hop2j_list = aggregate_radius(self.radius, g, feat_a)
        # apply linear transformation
        hop2j_list = [
            linear(x) for linear, x in zip(self.linear_radius, hop2j_list)
        ]
        radius_proj = sum(hop2j_list)

        # term "fuse"
        fuse = self.linear_fuse(th.mm(pm_pd, feat_b))

        # sum them together
        result = prev_proj + deg_proj + radius_proj + fuse

        # skip connection and batch norm
        n = self.out_feats // 2
        result = th.cat([result[:, :n], F.relu(result[:, n:])], 1)
        result = self.bn(result)

        return result

将 LGNN 抽象串联成一个 LGNN 层

实现

\[\begin{split}\begin{split} x^{(k+1)} = {}& f(x^{(k)}, y^{(k)})\\ y^{(k+1)} = {}& f(y^{(k)}, x^{(k+1)}) \end{split}\end{split}\]

串联两个 LGNNCore 实例,如示例代码所示,并在前向传播中使用不同的参数。

class LGNNLayer(nn.Module):
    def __init__(self, in_feats, out_feats, radius):
        super(LGNNLayer, self).__init__()
        self.g_layer = LGNNCore(in_feats, out_feats, radius)
        self.lg_layer = LGNNCore(in_feats, out_feats, radius)

    def forward(self, g, lg, x, lg_x, deg_g, deg_lg, pm_pd):
        next_x = self.g_layer(g, x, lg_x, deg_g, pm_pd)
        pm_pd_y = th.transpose(pm_pd, 0, 1)
        next_lg_x = self.lg_layer(lg, lg_x, x, deg_lg, pm_pd_y)
        return next_x, next_lg_x

串联 LGNN 层

定义一个具有三个隐藏层的 LGNN,如下例所示。

class LGNN(nn.Module):
    def __init__(self, radius):
        super(LGNN, self).__init__()
        self.layer1 = LGNNLayer(1, 16, radius)  # input is scalar feature
        self.layer2 = LGNNLayer(16, 16, radius)  # hidden size is 16
        self.layer3 = LGNNLayer(16, 16, radius)
        self.linear = nn.Linear(16, 2)  # predice two classes

    def forward(self, g, lg, pm_pd):
        # compute the degrees
        deg_g = g.in_degrees().float().unsqueeze(1)
        deg_lg = lg.in_degrees().float().unsqueeze(1)
        # use degree as the input feature
        x, lg_x = deg_g, deg_lg
        x, lg_x = self.layer1(g, lg, x, lg_x, deg_g, deg_lg, pm_pd)
        x, lg_x = self.layer2(g, lg, x, lg_x, deg_g, deg_lg, pm_pd)
        x, lg_x = self.layer3(g, lg, x, lg_x, deg_g, deg_lg, pm_pd)
        return self.linear(x)

训练与推理

首先加载数据。

from torch.utils.data import DataLoader

training_loader = DataLoader(
    train_set, batch_size=1, collate_fn=train_set.collate_fn, drop_last=True
)

接下来,定义主要的训练循环。请注意,每个训练样本包含三个对象:一个 DGLGraph,一个 SciPy 稀疏矩阵 pmpd,以及一个 numpy.ndarray 中的标签数组。使用以下命令生成线图

lg = g.line_graph(backtracking=False)

请注意,需要 backtracking=False 来正确模拟非回溯操作。我们还定义了一个工具函数,将 SciPy 稀疏矩阵转换为 torch 稀疏张量。

# Create the model
model = LGNN(radius=3)
# define the optimizer
optimizer = th.optim.Adam(model.parameters(), lr=1e-2)

# A utility function to convert a scipy.coo_matrix to torch.SparseFloat
def sparse2th(mat):
    value = mat.data
    indices = th.LongTensor([mat.row, mat.col])
    tensor = th.sparse.FloatTensor(
        indices, th.from_numpy(value).float(), mat.shape
    )
    return tensor


# Train for 20 epochs
for i in range(20):
    all_loss = []
    all_acc = []
    for [g, pmpd, label] in training_loader:
        # Generate the line graph.
        lg = g.line_graph(backtracking=False)
        # Create torch tensors
        pmpd = sparse2th(pmpd)
        label = th.from_numpy(label)

        # Forward
        z = model(g, lg, pmpd)

        # Calculate loss:
        # Since there are only two communities, there are only two permutations
        #  of the community labels.
        loss_perm1 = F.cross_entropy(z, label)
        loss_perm2 = F.cross_entropy(z, 1 - label)
        loss = th.min(loss_perm1, loss_perm2)

        # Calculate accuracy:
        _, pred = th.max(z, 1)
        acc_perm1 = (pred == label).float().mean()
        acc_perm2 = (pred == 1 - label).float().mean()
        acc = th.max(acc_perm1, acc_perm2)
        all_loss.append(loss.item())
        all_acc.append(acc.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    niters = len(all_loss)
    print(
        "Epoch %d | loss %.4f | accuracy %.4f"
        % (i, sum(all_loss) / niters, sum(all_acc) / niters)
    )
/dgl/tutorials/models/1_gnn/6_line_graph.py:561: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:261.)
  indices = th.LongTensor([mat.row, mat.col])
/dgl/tutorials/models/1_gnn/6_line_graph.py:562: UserWarning: torch.sparse.SparseTensor(indices, values, shape, *, device=) is deprecated.  Please use torch.sparse_coo_tensor(indices, values, shape, dtype=, device=). (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:605.)
  tensor = th.sparse.FloatTensor(
Epoch 0 | loss 0.5662 | accuracy 0.7168
Epoch 1 | loss 0.4632 | accuracy 0.7869
Epoch 2 | loss 0.4661 | accuracy 0.7697
Epoch 3 | loss 0.4650 | accuracy 0.7766
Epoch 4 | loss 0.4762 | accuracy 0.7747
Epoch 5 | loss 0.4929 | accuracy 0.7740
Epoch 6 | loss 0.4692 | accuracy 0.7904
Epoch 7 | loss 0.4506 | accuracy 0.7957
Epoch 8 | loss 0.4505 | accuracy 0.8016
Epoch 9 | loss 0.4572 | accuracy 0.7885
Epoch 10 | loss 0.4554 | accuracy 0.7876
Epoch 11 | loss 0.4506 | accuracy 0.8054
Epoch 12 | loss 0.4357 | accuracy 0.8028
Epoch 13 | loss 0.4170 | accuracy 0.8170
Epoch 14 | loss 0.4095 | accuracy 0.8164
Epoch 15 | loss 0.4155 | accuracy 0.8116
Epoch 16 | loss 0.4022 | accuracy 0.8235
Epoch 17 | loss 0.4103 | accuracy 0.8172
Epoch 18 | loss 0.4225 | accuracy 0.7993
Epoch 19 | loss 0.3940 | accuracy 0.8235

可视化训练过程

您可以可视化网络在一个训练示例上的社区预测,并与真实标签进行对比。使用以下代码示例开始。

pmpd1 = sparse2th(pmpd1)
LG1 = G1.line_graph(backtracking=False)
z = model(G1, LG1, pmpd1)
_, pred = th.max(z, 1)
visualize(pred, nx_G1)
6 line graph

与真实标签进行对比。请注意,两个社区的颜色可能颠倒了,因为模型的目的是正确预测划分。

visualize(label1, nx_G1)
6 line graph

这里有一个动画可以更好地理解这个过程。(40 轮)

lgnn-anim

批量处理图以实现并行

LGNN 处理一组不同的图。您可能会考虑是否可以使用批量处理来实现并行。

批量处理已经内置在数据加载器本身中。在 PyTorch 数据加载器的 collate_fn 中,图使用 DGL 的 batched_graph API 进行批量处理。DGL 通过将图合并到一个大图中来批量处理图,其中每个小图的邻接矩阵是大图邻接矩阵对角线上的一个块。将 :math`{Pm,Pd}` 拼接为块对角矩阵,以与 DGL 批量图 API 对应。

def collate_fn(batch):
    graphs, pmpds, labels = zip(*batch)
    batched_graphs = dgl.batch(graphs)
    batched_pmpds = sp.block_diag(pmpds)
    batched_labels = np.concatenate(labels, axis=0)
    return batched_graphs, batched_pmpds, batched_labels

您可以在 Github 上找到完整的代码:Community Detection with Graph Neural Networks (CDGNN)

脚本总运行时间: (0 分钟 22.600 秒)

由 Sphinx-Gallery 生成的图库