PGExplainer
- class dgl.nn.pytorch.explain.PGExplainer(model, num_features, num_hops=None, explain_graph=True, coff_budget=0.01, coff_connect=0.0005, sample_bias=0.0)[source]
Bases: Module
PGExplainer from Parameterized Explainer for Graph Neural Network <https://arxiv.org/pdf/2011.04573>
PGExplainer adopts a deep neural network (the explanation network) to parameterize the generation process of explanations, which enables it to explain multiple instances collectively. PGExplainer models the underlying structure as edge distributions, from which the explanatory graph is sampled.
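As a rough sketch of the idea (a hypothetical illustration, not the library's exact internals), each per-edge logit produced by the explanation network can be turned into an approximately binary yet differentiable mask via the binary concrete (Gumbel-Sigmoid) relaxation; the temperature below plays the same role as the temperature argument of the methods that follow:

import torch

def sample_edge_mask(omega, temperature=1.0, training=True):
    # omega: per-edge logits from the explanation network (hypothetical name)
    if training:
        # Binary concrete relaxation: inject U(0, 1) noise so the mask is
        # stochastic yet differentiable with respect to omega.
        eps = torch.rand_like(omega).clamp(1e-6, 1 - 1e-6)
        gumbel = torch.log(eps) - torch.log(1.0 - eps)
        return torch.sigmoid((gumbel + omega) / temperature)
    # At evaluation time, use the deterministic mean of the distribution.
    return torch.sigmoid(omega)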
- Parameters:
  - model (nn.Module) – The GNN model to explain, which tackles multiclass graph classification. Its forward function must have the form forward(self, graph, nfeat, embed, edge_weight). The output of its forward function is the logits if embed=False, and the intermediate node embeddings otherwise.
  - num_features (int) – The node embedding size used by model.
  - num_hops (int, optional) – The number of hops of GNN information aggregation, which must match the number of message-passing layers employed by the GNN to be explained.
  - explain_graph (bool, optional) – Whether to initialize the model for graph-level or node-level predictions.
  - coff_budget (float, optional) – Size regularization to constrain the size of the explanation. Default: 0.01.
  - coff_connect (float, optional) – Entropy regularization to constrain the connectivity of the explanation. Default: 5e-4.
  - sample_bias (float, optional) – The extent to which some members of a population are systematically more likely to be selected in a sample than others. Default: 0.0.
- explain_graph(graph, feat, temperature=1.0, training=False, **kwargs)[source]
Learn and return an edge mask that plays a crucial role in explaining the prediction made by the GNN for a graph. Also return the prediction made with the edges chosen based on the edge mask.
- Parameters:
  - graph (DGLGraph) – A homogeneous graph.
  - feat (Tensor) – The input node features of shape \((N, D)\), where \(N\) is the number of nodes and \(D\) is the feature size.
  - temperature (float, optional) – The temperature parameter fed to the edge sampling procedure.
  - training (bool, optional) – Whether the explanation network is being trained.
  - kwargs (dict) – Additional arguments passed to the GNN model.
- Returns:
  - Tensor – Classification probabilities given the masked graph. It is of shape \((B, L)\), where \(L\) is the number of different label types in the dataset and \(B\) is the batch size.
  - Tensor – Edge weights of shape \((E)\), where \(E\) is the number of edges in the graph. A higher weight indicates a larger contribution of the edge.
Example
>>> import torch as th
>>> import torch.nn as nn
>>> import dgl
>>> from dgl.data import GINDataset
>>> from dgl.dataloading import GraphDataLoader
>>> from dgl.nn import GraphConv, PGExplainer
>>> import numpy as np

>>> # Define the model
>>> class Model(nn.Module):
...     def __init__(self, in_feats, out_feats):
...         super().__init__()
...         self.conv = GraphConv(in_feats, out_feats)
...         self.fc = nn.Linear(out_feats, out_feats)
...         nn.init.xavier_uniform_(self.fc.weight)
...
...     def forward(self, g, h, embed=False, edge_weight=None):
...         h = self.conv(g, h, edge_weight=edge_weight)
...
...         if embed:
...             return h
...
...         with g.local_scope():
...             g.ndata['h'] = h
...             hg = dgl.mean_nodes(g, 'h')
...             return self.fc(hg)

>>> # Load dataset
>>> data = GINDataset('MUTAG', self_loop=True)
>>> dataloader = GraphDataLoader(data, batch_size=64, shuffle=True)

>>> # Train the model
>>> feat_size = data[0][0].ndata['attr'].shape[1]
>>> model = Model(feat_size, data.gclasses)
>>> criterion = nn.CrossEntropyLoss()
>>> optimizer = th.optim.Adam(model.parameters(), lr=1e-2)
>>> for bg, labels in dataloader:
...     preds = model(bg, bg.ndata['attr'])
...     loss = criterion(preds, labels)
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()

>>> # Initialize the explainer
>>> explainer = PGExplainer(model, data.gclasses)

>>> # Train the explainer
>>> # Define explainer temperature parameter
>>> init_tmp, final_tmp = 5.0, 1.0
>>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)
>>> for epoch in range(20):
...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))
...     for bg, labels in dataloader:
...         loss = explainer.train_step(bg, bg.ndata['attr'], tmp)
...         optimizer_exp.zero_grad()
...         loss.backward()
...         optimizer_exp.step()

>>> # Explain the prediction for graph 0
>>> graph, l = data[0]
>>> graph_feat = graph.ndata.pop("attr")
>>> probs, edge_weight = explainer.explain_graph(graph, graph_feat)
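The returned edge weights can then be thresholded to extract a compact explanatory subgraph. A minimal follow-up, where the top-10 cutoff is an arbitrary choice rather than part of the API:

>>> # Keep the 10 highest-weight edges as the explanation (k is arbitrary)
>>> top_edges = th.topk(edge_weight, k=10).indices
>>> explanation = dgl.edge_subgraph(graph, top_edges)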
- explain_node(nodes, graph, feat, temperature=1.0, training=False, **kwargs)[source]
Learn and return an edge mask that plays a crucial role in explaining the prediction made by the GNN for the provided set of node IDs. Also return the prediction made with the graph and the edge mask.
- Parameters:
  - nodes (int, iterable[int] or Tensor) – The IDs of the center nodes to explain; duplicate values are not allowed.
  - graph (DGLGraph) – A homogeneous graph.
  - feat (Tensor) – The input node features of shape \((N, D)\), where \(N\) is the number of nodes and \(D\) is the feature size.
  - temperature (float, optional) – The temperature parameter fed to the edge sampling procedure.
  - training (bool, optional) – Whether the explanation network is being trained.
  - kwargs (dict) – Additional arguments passed to the GNN model.
- Returns:
  - Tensor – Classification probabilities given the masked graph. It is of shape \((N, L)\), where \(L\) is the number of different node label types in the dataset and \(N\) is the number of nodes in the graph.
  - Tensor – Edge weights of shape \((E)\), where \(E\) is the number of edges in the graph. A higher weight indicates a larger contribution of the edge.
  - DGLGraph – The batched set of subgraphs induced on the k-hop in-neighborhood of the input center nodes.
  - Tensor – The new IDs of the subgraph center nodes.
Example
>>> import dgl
>>> import numpy as np
>>> import torch

>>> # Define the model
>>> class Model(torch.nn.Module):
...     def __init__(self, in_feats, out_feats):
...         super().__init__()
...         self.conv1 = dgl.nn.GraphConv(in_feats, out_feats)
...         self.conv2 = dgl.nn.GraphConv(out_feats, out_feats)
...
...     def forward(self, g, h, embed=False, edge_weight=None):
...         h = self.conv1(g, h, edge_weight=edge_weight)
...         if embed:
...             return h
...         return self.conv2(g, h)

>>> # Load dataset
>>> data = dgl.data.CoraGraphDataset(verbose=False)
>>> g = data[0]
>>> features = g.ndata["feat"]
>>> labels = g.ndata["label"]

>>> # Train the model
>>> model = Model(features.shape[1], data.num_classes)
>>> criterion = torch.nn.CrossEntropyLoss()
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
>>> for epoch in range(20):
...     logits = model(g, features)
...     loss = criterion(logits, labels)
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()

>>> # Initialize the explainer
>>> explainer = dgl.nn.PGExplainer(
...     model, data.num_classes, num_hops=2, explain_graph=False
... )

>>> # Train the explainer
>>> # Define explainer temperature parameter
>>> init_tmp, final_tmp = 5.0, 1.0
>>> optimizer_exp = torch.optim.Adam(explainer.parameters(), lr=0.01)
>>> epochs = 10
>>> for epoch in range(epochs):
...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / epochs))
...     loss = explainer.train_step_node(g.nodes(), g, features, tmp)
...     optimizer_exp.zero_grad()
...     loss.backward()
...     optimizer_exp.step()

>>> # Explain the prediction for node 0
>>> probs, edge_weight, bg, inverse_indices = explainer.explain_node(
...     0, g, features
... )
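Because the explanation is computed on the extracted k-hop subgraph, the returned inverse_indices locate the center node inside the batched subgraph bg. A minimal follow-up, reading off the prediction for that node:

>>> # Probabilities of the center node within the batched subgraph
>>> center_probs = probs[inverse_indices]
>>> predicted_class = center_probs.argmax(dim=-1)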
- forward(*input: Any) → None
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this function directly, since the former takes care of running the registered hooks while the latter silently ignores them.