SubgraphX
- class dgl.nn.pytorch.explain.SubgraphX(model, num_hops, coef=10.0, high2low=True, num_child=12, num_rollouts=20, node_min=3, shapley_steps=100, log=False)[source]
Bases: Module
SubgraphX from On Explainability of Graph Neural Networks via Subgraph Explorations <https://arxiv.org/abs/2102.05152>
It identifies the most important subgraph of the original graph, i.e., the subgraph that plays a crucial role in GNN-based graph classification.
It employs Monte Carlo tree search (MCTS) to efficiently explore different subgraphs for explanation and uses Shapley values as the measure of subgraph importance.
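As a rough sketch of the Monte Carlo estimate behind the shapley_steps parameter (notation ours, not taken from the DGL source; see the paper for the exact coalition sampling scheme): treating a candidate subgraph $S$ as a single player among the nodes $P$ in its local neighborhood, its importance is approximated as

$$\phi(S) \approx \frac{1}{T} \sum_{t=1}^{T} \Big( f(C_t \cup S) - f(C_t) \Big),$$

where each coalition $C_t \subseteq P \setminus S$ is drawn from a random permutation of the players, $f$ is the model's predicted score for the target class, and $T$ corresponds to shapley_steps.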
- Parameters:
model (nn.Module) – The GNN model to explain, which tackles the multiclass graph classification problem. Its forward function must have the form forward(self, graph, nfeat). The output of its forward function is the logits.
num_hops (int) – Number of message passing layers in the model
coef (float, optional) – This hyperparameter controls the trade-off between exploration and exploitation. A higher value encourages the algorithm to explore relatively unvisited nodes. Default: 10.0
high2low (bool, optional) – If True, it will use the "High2low" strategy for pruning actions when expanding children nodes in the search tree, i.e., expanding children nodes from nodes of high degree to nodes of low degree. Otherwise, it will use the "Low2high" strategy. Default: True
num_child (int, optional) – This is the number of children nodes to expand when expanding a node in the search tree. Default: 12
num_rollouts (int, optional) – This is the number of rollouts (simulations) for MCTS. Default: 20
num_min (int, optional) – This is the threshold for defining a leaf node based on the number of nodes in a subgraph. Default: 3
shapley_steps (int, optional) – This is the number of steps for Monte Carlo sampling when estimating Shapley values. Default: 100
log (bool, optional) – If True, it will log the progress. Default: False
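As a quick illustration of these knobs, here is a minimal sketch of a non-default configuration; model is assumed to be an already trained GNN with two message passing layers, and the values are illustrative rather than tuned recommendations:

>>> from dgl.nn import SubgraphX
>>> explainer = SubgraphX(model, num_hops=2,
...                       coef=5.0,         # lean toward exploitation
...                       high2low=False,   # expand from low-degree to high-degree nodes
...                       num_rollouts=10,  # fewer MCTS simulations: faster, coarser search
...                       node_min=5,       # treat subgraphs under 5 nodes as leaves
...                       log=True)         # print search progress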
- explain_graph(graph, feat, target_class, **kwargs)[source]
Find the most important subgraph of the original graph for the model to classify the graph into the target class.
- Parameters:
graph (DGLGraph) – A homogeneous graph
feat (Tensor) – The input node feature of shape (N, D), where N is the number of nodes and D is the feature size
target_class (int) – The target class to explain
kwargs (dict) – Additional arguments passed to the GNN model
- Returns:
Nodes that represent the most important subgraph
- Return type:
Tensor
Example
>>> import torch
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>> from dgl.data import GINDataset
>>> from dgl.dataloading import GraphDataLoader
>>> from dgl.nn import GraphConv, AvgPooling, SubgraphX

>>> # Define the model
>>> class Model(nn.Module):
...     def __init__(self, in_dim, n_classes, hidden_dim=128):
...         super().__init__()
...         self.conv1 = GraphConv(in_dim, hidden_dim)
...         self.conv2 = GraphConv(hidden_dim, n_classes)
...         self.pool = AvgPooling()
...
...     def forward(self, g, h):
...         h = F.relu(self.conv1(g, h))
...         h = self.conv2(g, h)
...         return self.pool(g, h)

>>> # Load dataset
>>> data = GINDataset('MUTAG', self_loop=True)
>>> dataloader = GraphDataLoader(data, batch_size=64, shuffle=True)

>>> # Train the model
>>> feat_size = data[0][0].ndata['attr'].shape[1]
>>> model = Model(feat_size, data.gclasses)
>>> criterion = nn.CrossEntropyLoss()
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
>>> for bg, labels in dataloader:
...     logits = model(bg, bg.ndata['attr'])
...     loss = criterion(logits, labels)
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()

>>> # Initialize the explainer
>>> explainer = SubgraphX(model, num_hops=2)

>>> # Explain the prediction for graph 0
>>> graph, l = data[0]
>>> graph_feat = graph.ndata.pop("attr")
>>> g_nodes_explain = explainer.explain_graph(graph, graph_feat,
...                                           target_class=l)
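The returned tensor holds node IDs, so a natural follow-up (continuing the example above) is to materialize the explanation as an induced subgraph with dgl.node_subgraph; the original node IDs are kept in the subgraph's dgl.NID node data:

>>> import dgl
>>> # Induced subgraph over the explanatory nodes
>>> sg = dgl.node_subgraph(graph, g_nodes_explain)
>>> sg.ndata[dgl.NID]  # IDs of the explanatory nodes in the original graph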