GNNExplainer

class dgl.nn.pytorch.explain.GNNExplainer(model, num_hops, lr=0.01, num_epochs=100, *, alpha1=0.005, alpha2=1.0, beta1=1.0, beta2=0.1, log=True)[source]

基类: Module

GNNExplainer 模型,来自论文 GNNExplainer: Generating Explanations for Graph Neural Networks

它识别出在基于 GNN 的节点分类和图分类中起关键作用的紧凑子图结构和少量节点特征子集。

为了生成解释,它通过优化以下目标函数来学习边掩码 \(M\) 和特征掩码 \(F\)

\[l(y, \hat{y}) + \alpha_1 \|M\|_1 + \alpha_2 H(M) + \beta_1 \|F\|_1 + \beta_2 H(F)\]

其中 \(l\) 是损失函数,\(y\) 是原始模型预测,\(\hat{y}\) 是应用边和特征掩码后的模型预测,\(H\) 是熵函数。

参数:
  • model (nn.Module) –

    待解释的 GNN 模型。

    • 其 forward 函数所需的参数为 graph 和 feat。后者用于输入节点特征。

    • 它还应可选地接受一个用于边权重的 eweight 参数,并在消息传递中将消息乘以该权重。

    • 其 forward 函数的输出是预测的节点/图类别的 logits。

    另请参见 explain_node()explain_graph() 中的示例。

  • num_hops (int) – GNN 信息聚合的跳数。

  • lr (float, optional) – 要使用的学习率,默认为 0.01。

  • num_epochs (int, optional) – 训练的 epoch 数。

  • alpha1 (float, optional) – 值越大,通过减小边掩码的总和,将使解释边掩码更稀疏。

  • alpha2 (float, optional) – 值越大,通过减小边掩码的熵,将使解释边掩码更稀疏。

  • beta1 (float, optional) – 值越大,通过减小节点特征掩码的均值,将使解释节点特征掩码更稀疏。

  • beta2 (float, optional) – 值越大,通过减小节点特征掩码的熵,将使解释节点特征掩码更稀疏。

  • log (bool, optional) – 如果为 True,将记录计算过程,默认为 True。

explain_graph(graph, feat, **kwargs)[source]

学习并返回一个节点特征掩码和边掩码,它们在解释 GNN 对图的预测中起着关键作用。

参数:
  • graph (DGLGraph) – 同构图。

  • feat (Tensor) – 输入特征,形状为 \((N, D)\)。其中 \(N\) 是节点数,\(D\) 是特征大小。

  • kwargs (dict) – 传递给 GNN 模型的附加参数。第一维是节点数或边数的张量将被假定为节点/边特征。

返回:

  • feat_mask (Tensor) – 学习到的特征重要性掩码,形状为 \((D)\),其中 \(D\) 是特征大小。值在范围 \((0, 1)\) 内。值越高,重要性越高。

  • edge_mask (Tensor) – 学习到的图中边的重要性掩码,形状为张量 \((E)\),其中 \(E\) 是图中的边数。值在范围 \((0, 1)\) 内。值越高,重要性越高。

示例

>>> import dgl.function as fn
>>> import torch
>>> import torch.nn as nn
>>> from dgl.data import GINDataset
>>> from dgl.dataloading import GraphDataLoader
>>> from dgl.nn import AvgPooling, GNNExplainer
>>> # Load dataset
>>> data = GINDataset('MUTAG', self_loop=True)
>>> dataloader = GraphDataLoader(data, batch_size=64, shuffle=True)
>>> # Define a model
>>> class Model(nn.Module):
...     def __init__(self, in_feats, out_feats):
...         super(Model, self).__init__()
...         self.linear = nn.Linear(in_feats, out_feats)
...         self.pool = AvgPooling()
...
...     def forward(self, graph, feat, eweight=None):
...         with graph.local_scope():
...             feat = self.linear(feat)
...             graph.ndata['h'] = feat
...             if eweight is None:
...                 graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
...             else:
...                 graph.edata['w'] = eweight
...                 graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))
...             return self.pool(graph, graph.ndata['h'])
>>> # Train the model
>>> feat_size = data[0][0].ndata['attr'].shape[1]
>>> model = Model(feat_size, data.gclasses)
>>> criterion = nn.CrossEntropyLoss()
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
>>> for bg, labels in dataloader:
...     logits = model(bg, bg.ndata['attr'])
...     loss = criterion(logits, labels)
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Explain the prediction for graph 0
>>> explainer = GNNExplainer(model, num_hops=1)
>>> g, _ = data[0]
>>> features = g.ndata['attr']
>>> feat_mask, edge_mask = explainer.explain_graph(g, features)
>>> feat_mask
tensor([0.2362, 0.2497, 0.2622, 0.2675, 0.2649, 0.2962, 0.2533])
>>> edge_mask
tensor([0.2154, 0.2235, 0.8325, ..., 0.7787, 0.1735, 0.1847])
explain_node(node_id, graph, feat, **kwargs)[source]

学习并返回一个节点特征掩码和子图,它们在解释 GNN 对节点 node_id 的预测中起着关键作用。

参数:
  • node_id (int) – 要解释的节点。

  • graph (DGLGraph) – 同构图。

  • feat (Tensor) – 输入特征,形状为 \((N, D)\)。其中 \(N\) 是节点数,\(D\) 是特征大小。

  • kwargs (dict) – 传递给 GNN 模型的附加参数。第一维是节点数或边数的张量将被假定为节点/边特征。

返回:

  • new_node_id (Tensor) – 输入中心节点的新 ID。

  • sg (DGLGraph) – 在输入中心节点的 k 跳入邻居上诱导的子图。

  • feat_mask (Tensor) – 学习到的节点特征重要性掩码,形状为 \((D)\),其中 \(D\) 是特征大小。值在范围 \((0, 1)\) 内。值越高,重要性越高。

  • edge_mask (Tensor) – 学习到的子图中边的重要性掩码,形状为张量 \((E)\),其中 \(E\) 是子图中的边数。值在范围 \((0, 1)\) 内。值越高,重要性越高。

示例

>>> import dgl
>>> import dgl.function as fn
>>> import torch
>>> import torch.nn as nn
>>> from dgl.data import CoraGraphDataset
>>> from dgl.nn import GNNExplainer
>>> # Load dataset
>>> data = CoraGraphDataset()
>>> g = data[0]
>>> features = g.ndata['feat']
>>> labels = g.ndata['label']
>>> train_mask = g.ndata['train_mask']
>>> # Define a model
>>> class Model(nn.Module):
...     def __init__(self, in_feats, out_feats):
...         super(Model, self).__init__()
...         self.linear = nn.Linear(in_feats, out_feats)
...
...     def forward(self, graph, feat, eweight=None):
...         with graph.local_scope():
...             feat = self.linear(feat)
...             graph.ndata['h'] = feat
...             if eweight is None:
...                 graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
...             else:
...                 graph.edata['w'] = eweight
...                 graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))
...             return graph.ndata['h']
>>> # Train the model
>>> model = Model(features.shape[1], data.num_classes)
>>> criterion = nn.CrossEntropyLoss()
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
>>> for epoch in range(10):
...     logits = model(g, features)
...     loss = criterion(logits[train_mask], labels[train_mask])
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Explain the prediction for node 10
>>> explainer = GNNExplainer(model, num_hops=1)
>>> new_center, sg, feat_mask, edge_mask = explainer.explain_node(10, g, features)
>>> new_center
tensor([1])
>>> sg.num_edges()
12
>>> # Old IDs of the nodes in the subgraph
>>> sg.ndata[dgl.NID]
tensor([ 9, 10, 11, 12])
>>> # Old IDs of the edges in the subgraph
>>> sg.edata[dgl.EID]
tensor([51, 53, 56, 48, 52, 57, 47, 50, 55, 46, 49, 54])
>>> feat_mask
tensor([0.2638, 0.2738, 0.3039,  ..., 0.2794, 0.2643, 0.2733])
>>> edge_mask
tensor([0.0937, 0.1496, 0.8287, 0.8132, 0.8825, 0.8515, 0.8146, 0.0915, 0.1145,
        0.9011, 0.1311, 0.8437])
forward(*input: Any) None

定义每次调用时执行的计算。

应由所有子类重写。

注意

尽管 forward pass 的实现需要在此函数中定义,但之后应调用 Module 实例而不是直接调用此函数,因为前者会负责运行已注册的钩子,而后者则会默默忽略它们。