GNNExplainer
- class dgl.nn.pytorch.explain.GNNExplainer(model, num_hops, lr=0.01, num_epochs=100, *, alpha1=0.005, alpha2=1.0, beta1=1.0, beta2=0.1, log=True)[source]
Bases: Module
GNNExplainer model from the paper GNNExplainer: Generating Explanations for Graph Neural Networks.
It identifies compact subgraph structures and small subsets of node features that play a critical role in GNN-based node classification and graph classification.
To generate an explanation, it learns an edge mask \(M\) and a feature mask \(F\) by optimizing the following objective function.
\[l(y, \hat{y}) + \alpha_1 \|M\|_1 + \alpha_2 H(M) + \beta_1 \|F\|_1 + \beta_2 H(F)\]
where \(l\) is the loss function, \(y\) is the original model prediction, \(\hat{y}\) is the model prediction with the edge and feature masks applied, and \(H\) is the entropy function.
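For concreteness, below is a minimal sketch of how this objective could be assembled in PyTorch. It is illustrative only: the function name explanation_loss, its arguments, the eps constant, and the choice of cross-entropy for \(l\) are assumptions, not the library's internal implementation.

import torch.nn.functional as F

def explanation_loss(masked_logits, target_class, edge_mask, feat_mask,
                     alpha1=0.005, alpha2=1.0, beta1=1.0, beta2=0.1, eps=1e-15):
    # Hypothetical sketch of the objective above, not the library's internal code.
    # edge_mask / feat_mask are raw learnable parameters; the sigmoid maps them
    # into (0, 1) so they act as soft masks M and F.
    m = edge_mask.sigmoid()
    f = feat_mask.sigmoid()
    # l(y, y_hat): assumed to be cross-entropy between the prediction made with
    # the masks applied and the class originally predicted by the model
    pred_loss = F.cross_entropy(masked_logits, target_class)
    # alpha1 * ||M||_1 : shrinking the sum of the edge mask makes it sparser
    edge_size = alpha1 * m.sum()
    # alpha2 * H(M) : element-wise entropy pushes edge mask values toward 0 or 1
    edge_ent = alpha2 * (-m * (m + eps).log() - (1 - m) * (1 - m + eps).log()).mean()
    # beta1 * ||F||_1 : shrinking the mean of the feature mask makes it sparser
    feat_size = beta1 * f.mean()
    # beta2 * H(F) : same entropy regularization for the feature mask
    feat_ent = beta2 * (-f * (f + eps).log() - (1 - f) * (1 - f + eps).log()).mean()
    return pred_loss + edge_size + edge_ent + feat_size + feat_ent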
- Parameters:
model (nn.Module) –
The GNN model to explain.
The required arguments of its forward function are graph and feat. The latter is for input node features.
It should also optionally take an eweight argument for edge weights and multiply the messages by it in message passing.
The output of its forward function is the logits for the predicted node/graph classes.
See also the examples in explain_node() and explain_graph().
num_hops (int) – The number of hops for GNN information aggregation.
lr (float, optional) – The learning rate to use, default to 0.01.
num_epochs (int, optional) – The number of epochs to train.
alpha1 (float, optional) – A higher value will make the explanation edge masks more sparse by decreasing the sum of the edge mask.
alpha2 (float, optional) – A higher value will make the explanation edge masks more sparse by decreasing the entropy of the edge mask.
beta1 (float, optional) – A higher value will make the explanation node feature masks more sparse by decreasing the mean of the node feature mask.
beta2 (float, optional) – A higher value will make the explanation node feature masks more sparse by decreasing the entropy of the node feature mask.
log (bool, optional) – If True, it will log the computation process, default to True.
- explain_graph(graph, feat, **kwargs)[source]
Learn and return a node feature mask and an edge mask that play a crucial role to explain the prediction made by the GNN for the graph.
- Parameters:
graph (DGLGraph) – A homogeneous graph.
feat (Tensor) – The input node feature of shape \((N, D)\), where \(N\) is the number of nodes and \(D\) is the feature size.
kwargs (dict) – Additional arguments passed to the GNN model.
- Returns:
feat_mask (Tensor) – Learned feature importance mask of shape \((D)\), where \(D\) is the feature size. The values are within range \((0, 1)\). The higher, the more important.
edge_mask (Tensor) – Learned importance mask of the edges in the graph, which is a tensor of shape \((E)\), where \(E\) is the number of edges in the graph. The values are within range \((0, 1)\). The higher, the more important.
Examples
>>> import dgl.function as fn
>>> import torch
>>> import torch.nn as nn
>>> from dgl.data import GINDataset
>>> from dgl.dataloading import GraphDataLoader
>>> from dgl.nn import AvgPooling, GNNExplainer

>>> # Load dataset
>>> data = GINDataset('MUTAG', self_loop=True)
>>> dataloader = GraphDataLoader(data, batch_size=64, shuffle=True)

>>> # Define a model
>>> class Model(nn.Module):
...     def __init__(self, in_feats, out_feats):
...         super(Model, self).__init__()
...         self.linear = nn.Linear(in_feats, out_feats)
...         self.pool = AvgPooling()
...
...     def forward(self, graph, feat, eweight=None):
...         with graph.local_scope():
...             feat = self.linear(feat)
...             graph.ndata['h'] = feat
...             if eweight is None:
...                 graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
...             else:
...                 graph.edata['w'] = eweight
...                 graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))
...             return self.pool(graph, graph.ndata['h'])

>>> # Train the model
>>> feat_size = data[0][0].ndata['attr'].shape[1]
>>> model = Model(feat_size, data.gclasses)
>>> criterion = nn.CrossEntropyLoss()
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
>>> for bg, labels in dataloader:
...     logits = model(bg, bg.ndata['attr'])
...     loss = criterion(logits, labels)
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()

>>> # Explain the prediction for graph 0
>>> explainer = GNNExplainer(model, num_hops=1)
>>> g, _ = data[0]
>>> features = g.ndata['attr']
>>> feat_mask, edge_mask = explainer.explain_graph(g, features)
>>> feat_mask
tensor([0.2362, 0.2497, 0.2622, 0.2675, 0.2649, 0.2962, 0.2533])
>>> edge_mask
tensor([0.2154, 0.2235, 0.8325, ..., 0.7787, 0.1735, 0.1847])
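The returned masks can then be post-processed to extract an explanatory subgraph. The lines below are a hypothetical follow-up using dgl.edge_subgraph, not part of the GNNExplainer API; keeping the 10 highest-weighted edges is an arbitrary choice.

>>> # Illustrative post-processing (assumption): keep the 10 highest-weighted edges
>>> import dgl
>>> topk_eids = edge_mask.topk(10).indices
>>> explanation_sg = dgl.edge_subgraph(g, topk_eids)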
- explain_node(node_id, graph, feat, **kwargs)[source]
Learn and return a node feature mask and a subgraph that play a crucial role to explain the prediction made by the GNN for node node_id.
- Parameters:
node_id (int) – The node to explain.
graph (DGLGraph) – A homogeneous graph.
feat (Tensor) – The input node feature of shape \((N, D)\), where \(N\) is the number of nodes and \(D\) is the feature size.
kwargs (dict) – Additional arguments passed to the GNN model.
- Returns:
new_node_id (Tensor) – The new ID of the input center node.
sg (DGLGraph) – The subgraph induced on the k-hop in-neighborhood of the input center node.
feat_mask (Tensor) – Learned node feature importance mask of shape \((D)\), where \(D\) is the feature size. The values are within range \((0, 1)\). The higher, the more important.
edge_mask (Tensor) – Learned importance mask of the edges in the subgraph, which is a tensor of shape \((E)\), where \(E\) is the number of edges in the subgraph. The values are within range \((0, 1)\). The higher, the more important.
Examples
>>> import dgl
>>> import dgl.function as fn
>>> import torch
>>> import torch.nn as nn
>>> from dgl.data import CoraGraphDataset
>>> from dgl.nn import GNNExplainer

>>> # Load dataset
>>> data = CoraGraphDataset()
>>> g = data[0]
>>> features = g.ndata['feat']
>>> labels = g.ndata['label']
>>> train_mask = g.ndata['train_mask']

>>> # Define a model
>>> class Model(nn.Module):
...     def __init__(self, in_feats, out_feats):
...         super(Model, self).__init__()
...         self.linear = nn.Linear(in_feats, out_feats)
...
...     def forward(self, graph, feat, eweight=None):
...         with graph.local_scope():
...             feat = self.linear(feat)
...             graph.ndata['h'] = feat
...             if eweight is None:
...                 graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
...             else:
...                 graph.edata['w'] = eweight
...                 graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))
...             return graph.ndata['h']

>>> # Train the model
>>> model = Model(features.shape[1], data.num_classes)
>>> criterion = nn.CrossEntropyLoss()
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
>>> for epoch in range(10):
...     logits = model(g, features)
...     loss = criterion(logits[train_mask], labels[train_mask])
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()

>>> # Explain the prediction for node 10
>>> explainer = GNNExplainer(model, num_hops=1)
>>> new_center, sg, feat_mask, edge_mask = explainer.explain_node(10, g, features)
>>> new_center
tensor([1])
>>> sg.num_edges()
12
>>> # Old IDs of the nodes in the subgraph
>>> sg.ndata[dgl.NID]
tensor([ 9, 10, 11, 12])
>>> # Old IDs of the edges in the subgraph
>>> sg.edata[dgl.EID]
tensor([51, 53, 56, 48, 52, 57, 47, 50, 55, 46, 49, 54])
>>> feat_mask
tensor([0.2638, 0.2738, 0.3039, ..., 0.2794, 0.2643, 0.2733])
>>> edge_mask
tensor([0.0937, 0.1496, 0.8287, 0.8132, 0.8825, 0.8515, 0.8146, 0.0915, 0.1145, 0.9011, 0.1311, 0.8437])
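Since the subgraph stores the original node and edge IDs, the learned edge mask can be mapped back to the full graph. The lines below are a hypothetical follow-up, not part of the GNNExplainer API; the 0.8 threshold is an arbitrary choice.

>>> # Illustrative post-processing (assumption): map edges with mask value above 0.8
>>> # back to their IDs in the original graph
>>> important = edge_mask > 0.8
>>> original_eids = sg.edata[dgl.EID][important]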