HeteroPGExplainer

class dgl.nn.pytorch.explain.HeteroPGExplainer(model, num_features, num_hops=None, explain_graph=True, coff_budget=0.01, coff_connect=0.0005, sample_bias=0.0)[source]

基类:PGExplainer

PGExplainer 来自 Parameterized Explainer for Graph Neural Network,适用于异构图

PGExplainer 采用深度神经网络(解释网络)来参数化解释的生成过程,这使其能够集体解释多个实例。PGExplainer 将底层结构建模为边分布,并从中采样生成解释图。

参数
  • model (nn.Module) —

    用于解释的 GNN 模型,处理多类别图分类

    • 其 forward 函数必须具有以下形式:forward(self, graph, nfeat, embed, edge_weight)

    • 如果 embed=False,其 forward 函数的输出是 logits;否则是中间节点嵌入。

  • num_features (int) — model 使用的节点嵌入尺寸。

  • coff_budget (float, 可选) — 用于约束解释尺寸的尺寸正则化项。默认值:0.01。

  • coff_connect (float, 可选) — 用于约束解释连通性的熵正则化项。默认值:5e-4。

  • sample_bias (float, 可选) — 总体的某些成员在样本中被选中的可能性系统地高于其他成员。默认值:0.0。

explain_graph(graph, feat, temperature=1.0, training=False, **kwargs)[source]

学习并返回一个边掩码,该掩码在解释 GNN 对图做出的预测中起关键作用。同时,返回基于该边掩码选择的边所做出的预测。

参数
  • graph (DGLGraph) — 异构图。

  • feat (dict[str, Tensor]) — 将节点类型(键)映射到特征张量(值)的字典。输入特征的形状为 \((N_t, D_t)\)\(N_t\) 是节点类型 \(t\) 的节点数,\(D_t\) 是节点类型 \(t\) 的特征尺寸。

  • temperature (float) — 提供给采样过程的温度参数。

  • training (bool) — 是否训练解释网络。

  • kwargs (dict) — 传递给 GNN 模型的附加参数。

返回

  • Tensor — 给定掩码图的分类概率。这是形状为 \((B, L)\) 的张量,其中 \(L\) 是数据集中标签的不同类型,\(B\) 是批处理大小。

  • dict[str, Tensor] — 将边类型(键)映射到形状为 \((E_t)\) 的边张量(值)的字典,其中 \(E_t\) 是图中断边类型 \(t\) 的边数。权重越高表示该边的贡献越大。

示例

>>> import dgl
>>> import torch as th
>>> import torch.nn as nn
>>> import numpy as np
>>> # Define the model
>>> class Model(nn.Module):
...     def __init__(self, in_feats, hid_feats, out_feats, rel_names):
...         super().__init__()
...         self.conv = dgl.nn.HeteroGraphConv(
...             {rel: dgl.nn.GraphConv(in_feats, hid_feats) for rel in rel_names},
...             aggregate="sum",
...         )
...         self.fc = nn.Linear(hid_feats, out_feats)
...         nn.init.xavier_uniform_(self.fc.weight)
...
...     def forward(self, g, h, embed=False, edge_weight=None):
...         if edge_weight:
...             mod_kwargs = {
...                 etype: {"edge_weight": mask} for etype, mask in edge_weight.items()
...             }
...             h = self.conv(g, h, mod_kwargs=mod_kwargs)
...         else:
...             h = self.conv(g, h)
...
...         if embed:
...             return h
...
...         with g.local_scope():
...             g.ndata["h"] = h
...             hg = 0
...             for ntype in g.ntypes:
...                 hg = hg + dgl.mean_nodes(g, "h", ntype=ntype)
...             return self.fc(hg)
>>> # Load dataset
>>> input_dim = 5
>>> hidden_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes["user"].data["h"] = th.randn(g.num_nodes("user"), input_dim)
>>> g.nodes["game"].data["h"] = th.randn(g.num_nodes("game"), input_dim)
>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)
>>> # define and train the model
>>> model = Model(input_dim, hidden_dim, num_classes, g.canonical_etypes)
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
...     logits = model(g, g.ndata["h"])
...     loss = th.nn.functional.cross_entropy(logits, th.tensor([1]))
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Initialize the explainer
>>> explainer = dgl.nn.HeteroPGExplainer(model, hidden_dim)
>>> # Train the explainer
>>> # Define explainer temperature parameter
>>> init_tmp, final_tmp = 5.0, 1.0
>>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)
>>> for epoch in range(20):
...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))
...     loss = explainer.train_step(g, g.ndata["h"], tmp)
...     optimizer_exp.zero_grad()
...     loss.backward()
...     optimizer_exp.step()
>>> # Explain the graph
>>> feat = g.ndata.pop("h")
>>> probs, edge_mask = explainer.explain_graph(g, feat)
explain_node(nodes, graph, feat, temperature=1.0, training=False, **kwargs)[source]

学习并返回一个边掩码,该掩码在解释 GNN 对提供的节点 ID 集做出的预测中起关键作用。同时,返回使用批处理图和边掩码所做出的预测。

参数
  • nodes (dict[str, Iterable[int]]) — 将节点类型(键)映射到节点 ID 可迭代集合(值)的字典。

  • graph (DGLGraph) — 异构图。

  • feat (dict[str, Tensor]) — 将节点类型(键)映射到特征张量(值)的字典。输入特征的形状为 \((N_t, D_t)\)\(N_t\) 是节点类型 \(t\) 的节点数,\(D_t\) 是节点类型 \(t\) 的特征尺寸。

  • temperature (float) — 提供给采样过程的温度参数。

  • training (bool) — 是否训练解释网络。

  • kwargs (dict) — 传递给 GNN 模型的附加参数。

返回

  • dict[str, Tensor] — 将节点类型(键)映射到节点标签分类概率(值)的字典。值为形状为 \((N_t, L)\) 的张量,其中 \(L\) 是数据集中节点标签的不同类型,\(N_t\) 是图中断节点类型 \(t\) 的节点数。

  • dict[str, Tensor] — 将边类型(键)映射到形状为 \((E_t)\) 的边张量(值)的字典,其中 \(E_t\) 是图中断边类型 \(t\) 的边数。权重越高表示该边的贡献越大。

  • DGLGraph — 在输入中心节点的 k 跳入邻域上诱导的子图的批处理集合。

  • dict[str, Tensor] — 将节点类型(键)映射到节点 ID 张量(值)的字典,这些节点 ID 对应于子图的中心节点。

示例

>>> import dgl
>>> import torch as th
>>> import torch.nn as nn
>>> import numpy as np
>>> # Define the model
>>> class Model(nn.Module):
...     def __init__(self, in_feats, hid_feats, out_feats, rel_names):
...         super().__init__()
...         self.conv = dgl.nn.HeteroGraphConv(
...             {rel: dgl.nn.GraphConv(in_feats, hid_feats) for rel in rel_names},
...             aggregate="sum",
...         )
...         self.fc = nn.Linear(hid_feats, out_feats)
...         nn.init.xavier_uniform_(self.fc.weight)
...
...     def forward(self, g, h, embed=False, edge_weight=None):
...         if edge_weight:
...             mod_kwargs = {
...                 etype: {"edge_weight": mask} for etype, mask in edge_weight.items()
...             }
...             h = self.conv(g, h, mod_kwargs=mod_kwargs)
...         else:
...             h = self.conv(g, h)
...
...         return h
>>> # Load dataset
>>> input_dim = 5
>>> hidden_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes["user"].data["h"] = th.randn(g.num_nodes("user"), input_dim)
>>> g.nodes["game"].data["h"] = th.randn(g.num_nodes("game"), input_dim)
>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)
>>> # define and train the model
>>> model = Model(input_dim, hidden_dim, num_classes, g.canonical_etypes)
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
...     logits = model(g, g.ndata["h"])['user']
...     loss = th.nn.functional.cross_entropy(logits, th.tensor([1,1,1]))
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Initialize the explainer
>>> explainer = dgl.nn.HeteroPGExplainer(
...     model, hidden_dim, num_hops=2, explain_graph=False
... )
>>> # Train the explainer
>>> # Define explainer temperature parameter
>>> init_tmp, final_tmp = 5.0, 1.0
>>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)
>>> for epoch in range(20):
...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))
...     loss = explainer.train_step_node(
...         { ntype: g.nodes(ntype) for ntype in g.ntypes },
...         g, g.ndata["h"], tmp
...     )
...     optimizer_exp.zero_grad()
...     loss.backward()
...     optimizer_exp.step()
>>> # Explain the graph
>>> feat = g.ndata.pop("h")
>>> probs, edge_mask, bg, inverse_indices = explainer.explain_node(
...     { "user": [0] }, g, feat
... )
forward(*input: Any) None

定义每次调用时执行的计算。

应由所有子类覆盖。

注意

虽然 forward pass 的实现需要在该函数中定义,但之后应调用 Module 实例而不是直接调用此函数,因为前者会处理运行已注册的钩子,而后者会静默忽略它们。

train_step(graph, feat, temperature, **kwargs)[source]

计算用于图分类的解释网络的损失。

参数
  • graph (DGLGraph) — 输入的批处理异构图。

  • feat (dict[str, Tensor]) — 将节点类型(键)映射到特征张量(值)的字典。输入特征的形状为 \((N_t, D_t)\)\(N_t\) 是节点类型 \(t\) 的节点数,\(D_t\) 是节点类型 \(t\) 的特征尺寸。

  • temperature (float) — 提供给采样过程的温度参数。

  • kwargs (dict) — 传递给 GNN 模型的附加参数。

返回

表示损失的标量张量。

返回类型

Tensor

train_step_node(nodes, graph, feat, temperature, **kwargs)[source]

计算用于节点分类的解释网络的损失。

参数
  • nodes (dict[str, Iterable[int]]) — 将节点类型(键)映射到节点 ID 可迭代集合(值)的字典。

  • graph (DGLGraph) — 输入的异构图。

  • feat (dict[str, Tensor]) — 将节点类型(键)映射到特征张量(值)的字典。输入特征的形状为 \((N_t, D_t)\)\(N_t\) 是节点类型 \(t\) 的节点数,\(D_t\) 是节点类型 \(t\) 的特征尺寸。

  • temperature (float) — 提供给采样过程的温度参数。

  • kwargs (dict) — 传递给 GNN 模型的附加参数。

返回

表示损失的标量张量。

返回类型

Tensor