PGExplainer

class dgl.nn.pytorch.explain.PGExplainer(model, num_features, num_hops=None, explain_graph=True, coff_budget=0.01, coff_connect=0.0005, sample_bias=0.0)[source]

基类: Module

PGExplainer,出自 Parameterized Explainer for Graph Neural Network <https://arxiv.org/pdf/2011.04573>

PGExplainer 采用深度神经网络(解释网络)来参数化解释的生成过程,使其能够集体解释多个实例。PGExplainer 将底层结构建模为边分布,并从中采样解释性图。

参数:
  • model (nn.Module)

    用于解释处理多类图分类的 GNN 模型

    • 其 forward 函数必须具有 forward(self, graph, nfeat, embed, edge_weight) 的形式。

    • 如果 embed=False,其 forward 函数的输出是 logits;否则是中间节点嵌入。

  • num_features (int) – model 使用的节点嵌入大小。

  • num_hops (int, optional) – GNN 信息聚合的跳数,必须与待解释 GNN 所使用的消息传递层数匹配。

  • explain_graph (bool, optional) – 是否初始化模型用于图级别或节点级别预测。

  • coff_budget (float, optional) – 大小正则化项,用于约束解释的大小。默认值: 0.01。

  • coff_connect (float, optional) – 熵正则化项,用于约束解释的连通性。默认值: 5e-4。

  • sample_bias (float, optional) – 种群中的某些成员在样本中被选中的可能性系统地高于其他成员。默认值: 0.0。

explain_graph(graph, feat, temperature=1.0, training=False, **kwargs)[source]

学习并返回一个边掩码,该掩码在解释 GNN 对图的预测中起着关键作用。同时,返回使用根据边掩码选择的边进行的预测。

参数:
  • graph (DGLGraph) – 同构图。

  • feat (Tensor) – 输入特征,形状为 \((N, D)\)\(N\) 是节点数,\(D\) 是特征大小。

  • temperature (float) – 采样过程中使用的温度参数。

  • training (bool) – 是否训练解释网络。

  • kwargs (dict) – 传递给 GNN 模型的额外参数。

返回:

  • Tensor – 给定掩码图的分类概率。其形状为 \((B, L)\),其中 \(L\) 是数据集中不同类型的标签数量,\(B\) 是批次大小。

  • Tensor – 边权重,形状为 \((E)\),其中 \(E\) 是图中的边数。权重越高表示该边的贡献越大。

示例

>>> import torch as th
>>> import torch.nn as nn
>>> import dgl
>>> from dgl.data import GINDataset
>>> from dgl.dataloading import GraphDataLoader
>>> from dgl.nn import GraphConv, PGExplainer
>>> import numpy as np
>>> # Define the model
>>> class Model(nn.Module):
...     def __init__(self, in_feats, out_feats):
...         super().__init__()
...         self.conv = GraphConv(in_feats, out_feats)
...         self.fc = nn.Linear(out_feats, out_feats)
...         nn.init.xavier_uniform_(self.fc.weight)
...
...     def forward(self, g, h, embed=False, edge_weight=None):
...         h = self.conv(g, h, edge_weight=edge_weight)
...
...         if embed:
...             return h
...
...         with g.local_scope():
...             g.ndata['h'] = h
...             hg = dgl.mean_nodes(g, 'h')
...             return self.fc(hg)
>>> # Load dataset
>>> data = GINDataset('MUTAG', self_loop=True)
>>> dataloader = GraphDataLoader(data, batch_size=64, shuffle=True)
>>> # Train the model
>>> feat_size = data[0][0].ndata['attr'].shape[1]
>>> model = Model(feat_size, data.gclasses)
>>> criterion = nn.CrossEntropyLoss()
>>> optimizer = th.optim.Adam(model.parameters(), lr=1e-2)
>>> for bg, labels in dataloader:
...     preds = model(bg, bg.ndata['attr'])
...     loss = criterion(preds, labels)
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Initialize the explainer
>>> explainer = PGExplainer(model, data.gclasses)
>>> # Train the explainer
>>> # Define explainer temperature parameter
>>> init_tmp, final_tmp = 5.0, 1.0
>>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)
>>> for epoch in range(20):
...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))
...     for bg, labels in dataloader:
...          loss = explainer.train_step(bg, bg.ndata['attr'], tmp)
...          optimizer_exp.zero_grad()
...          loss.backward()
...          optimizer_exp.step()
>>> # Explain the prediction for graph 0
>>> graph, l = data[0]
>>> graph_feat = graph.ndata.pop("attr")
>>> probs, edge_weight = explainer.explain_graph(graph, graph_feat)
explain_node(nodes, graph, feat, temperature=1.0, training=False, **kwargs)[source]

学习并返回一个边掩码,该掩码在解释 GNN 对提供的节点 ID 集合的预测中起着关键作用。同时,返回使用图和边掩码进行的预测。

参数:
  • nodes (int, iterable[int], tensor) – 图中的节点,不能包含重复值。

  • graph (DGLGraph) – 同构图。

  • feat (Tensor) – 输入特征,形状为 \((N, D)\)\(N\) 是节点数,\(D\) 是特征大小。

  • temperature (float) – 采样过程中使用的温度参数。

  • training (bool) – 是否训练解释网络。

  • kwargs (dict) – 传递给 GNN 模型的额外参数。

返回:

  • Tensor – 给定掩码图的分类概率。其形状为 \((N, L)\),其中 \(L\) 是数据集中不同类型的节点标签数量,\(N\) 是图中的节点数。

  • Tensor – 边权重,形状为 \((E)\),其中 \(E\) 是图中的边数。权重越高表示该边的贡献越大。

  • DGLGraph – 在输入中心节点的 k-跳入邻域上诱导的批次子图集合。

  • Tensor – 子图中心节点的新 ID。

示例

>>> import dgl
>>> import numpy as np
>>> import torch
>>> # Define the model
>>> class Model(torch.nn.Module):
...     def __init__(self, in_feats, out_feats):
...         super().__init__()
...         self.conv1 = dgl.nn.GraphConv(in_feats, out_feats)
...         self.conv2 = dgl.nn.GraphConv(out_feats, out_feats)
...
...     def forward(self, g, h, embed=False, edge_weight=None):
...         h = self.conv1(g, h, edge_weight=edge_weight)
...         if embed:
...             return h
...         return self.conv2(g, h)
>>> # Load dataset
>>> data = dgl.data.CoraGraphDataset(verbose=False)
>>> g = data[0]
>>> features = g.ndata["feat"]
>>> labels = g.ndata["label"]
>>> # Train the model
>>> model = Model(features.shape[1], data.num_classes)
>>> criterion = torch.nn.CrossEntropyLoss()
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
>>> for epoch in range(20):
...     logits = model(g, features)
...     loss = criterion(logits, labels)
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Initialize the explainer
>>> explainer = dgl.nn.PGExplainer(
...     model, data.num_classes, num_hops=2, explain_graph=False
... )
>>> # Train the explainer
>>> # Define explainer temperature parameter
>>> init_tmp, final_tmp = 5.0, 1.0
>>> optimizer_exp = torch.optim.Adam(explainer.parameters(), lr=0.01)
>>> epochs = 10
>>> for epoch in range(epochs):
...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / epochs))
...     loss = explainer.train_step_node(g.nodes(), g, features, tmp)
...     optimizer_exp.zero_grad()
...     loss.backward()
...     optimizer_exp.step()
>>> # Explain the prediction for graph 0
>>> probs, edge_weight, bg, inverse_indices = explainer.explain_node(
...     0, g, features
... )
forward(*input: Any) None

定义每次调用时执行的计算。

应被所有子类覆盖。

注意

虽然前向传播的实现需要在该函数中定义,但之后应调用 Module 实例而不是直接调用此函数,因为前者负责运行已注册的钩子,而后者则默默忽略它们。

train_step(graph, feat, temperature, **kwargs)[source]

计算用于图分类的解释网络的损失

参数:
  • graph (DGLGraph) – 输入的批次同构图。

  • feat (Tensor) – 输入特征,形状为 \((N, D)\)\(N\) 是节点数,\(D\) 是特征大小。

  • temperature (float) – 采样过程中使用的温度参数。

  • kwargs (dict) – 传递给 GNN 模型的额外参数。

返回:

表示损失的标量张量。

返回类型:

Tensor

train_step_node(nodes, graph, feat, temperature, **kwargs)[source]

计算用于节点分类的解释网络的损失

参数:
  • nodes (int, iterable[int], tensor) – 用于训练解释网络的图中的节点,不能包含重复值。

  • graph (DGLGraph) – 输入的同构图。

  • feat (Tensor) – 输入特征,形状为 \((N, D)\)\(N\) 是节点数,\(D\) 是特征大小。

  • temperature (float) – 采样过程中使用的温度参数。

  • kwargs (dict) – 传递给 GNN 模型的额外参数。

返回:

表示损失的标量张量。

返回类型:

Tensor