SubgraphX

class dgl.nn.pytorch.explain.SubgraphX(model, num_hops, coef=10.0, high2low=True, num_child=12, num_rollouts=20, node_min=3, shapley_steps=100, log=False)[源码]

基类: Module

SubgraphX 来自 On Explainability of Graph Neural Networks via Subgraph Explorations <https://arxiv.org/abs/2102.05152>

它识别原始图中最重要的子图,该子图在基于 GNN 的图分类中起关键作用。

它采用蒙特卡洛树搜索 (MCTS) 高效探索不同的子图进行解释,并使用 Shapley 值作为子图重要性的度量。

参数:
  • model (nn.Module) –

    要解释的 GNN 模型,用于解决多分类图分类问题

    • 其 forward 函数必须具有 forward(self, graph, nfeat) 的形式。

    • 其 forward 函数的输出是 logits。

  • num_hops (int) – 模型中的消息传递层数

  • coef (float, optional) – 此超参数控制探索和利用之间的权衡。值越高,算法越倾向于探索相对未访问的节点。默认值: 10.0

  • high2low (bool, optional) – 如果为 True,在搜索树中扩展子节点时,它将使用“High2low”策略进行剪枝操作,从高度节点扩展到低度节点。否则,它将使用“Low2high”策略。默认值: True

  • num_child (int, optional) – 这是在搜索树中扩展子节点时要扩展的子节点数量。默认值: 12

  • num_rollouts (int, optional) – 这是 MCTS 的模拟次数(rollouts)。默认值: 20

  • node_min (int, optional) – 这是根据子图中的节点数量定义叶节点的阈值。默认值: 3

  • shapley_steps (int, optional) – 这是蒙特卡洛采样估计 Shapley 值时的步数。默认值: 100

  • log (bool, optional) – 如果为 True,将记录进度。默认值: False

explain_graph(graph, feat, target_class, **kwargs)[源码]

找到原始图中最重要的子图,以便模型将图分类到目标类别。

参数:
  • graph (DGLGraph) – 同构图

  • feat (Tensor) – 输入节点特征,形状为 \((N, D)\),其中 \(N\) 是节点数量,\(D\) 是特征维度

  • target_class (int) – 要解释的目标类别

  • kwargs (dict) – 传递给 GNN 模型的额外参数

返回值:

表示最重要子图的节点

返回类型:

Tensor

示例

>>> import torch
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>> from dgl.data import GINDataset
>>> from dgl.dataloading import GraphDataLoader
>>> from dgl.nn import GraphConv, AvgPooling, SubgraphX
>>> # Define the model
>>> class Model(nn.Module):
...     def __init__(self, in_dim, n_classes, hidden_dim=128):
...         super().__init__()
...         self.conv1 = GraphConv(in_dim, hidden_dim)
...         self.conv2 = GraphConv(hidden_dim, n_classes)
...         self.pool = AvgPooling()
...
...     def forward(self, g, h):
...         h = F.relu(self.conv1(g, h))
...         h = self.conv2(g, h)
...         return self.pool(g, h)
>>> # Load dataset
>>> data = GINDataset('MUTAG', self_loop=True)
>>> dataloader = GraphDataLoader(data, batch_size=64, shuffle=True)
>>> # Train the model
>>> feat_size = data[0][0].ndata['attr'].shape[1]
>>> model = Model(feat_size, data.gclasses)
>>> criterion = nn.CrossEntropyLoss()
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
>>> for bg, labels in dataloader:
...     logits = model(bg, bg.ndata['attr'])
...     loss = criterion(logits, labels)
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Initialize the explainer
>>> explainer = SubgraphX(model, num_hops=2)
>>> # Explain the prediction for graph 0
>>> graph, l = data[0]
>>> graph_feat = graph.ndata.pop("attr")
>>> g_nodes_explain = explainer.explain_graph(graph, graph_feat,
...                                           target_class=l)
forward(*input: Any) None

定义每次调用时执行的计算。

应由所有子类覆盖。

注意

尽管需要在函数内定义前向传播的实现,但后续应调用 Module 实例而不是此函数本身,因为前者会处理注册的钩子,而后者会默默忽略它们。