HeteroPGExplainer
- class dgl.nn.pytorch.explain.HeteroPGExplainer(model, num_features, num_hops=None, explain_graph=True, coff_budget=0.01, coff_connect=0.0005, sample_bias=0.0)[source]
Bases:
PGExplainer
PGExplainer from Parameterized Explainer for Graph Neural Network, adapted for heterogeneous graphs.
PGExplainer adopts a deep neural network (the explanation network) to parameterize the generation process of explanations, which enables it to explain multiple instances collectively. PGExplainer models the underlying structure as edge distributions, from which the explanatory graph is sampled.
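Concretely, the PGExplainer paper relaxes each binary edge selection into a differentiable sample from a binary concrete (Gumbel-Sigmoid) distribution. As a sketch of that sampling step, with \(\omega_{ij}\) denoting the latent edge weight produced by the explanation network and \(\tau\) the temperature parameter used throughout this page:

\[\epsilon \sim \mathrm{Uniform}(0, 1), \qquad \hat{e}_{ij} = \sigma\!\left(\frac{\log \epsilon - \log(1 - \epsilon) + \omega_{ij}}{\tau}\right)\]

As \(\tau \to 0\) the samples approach hard binary edge selections, which is why the training loops in the examples below anneal the temperature from init_tmp down to final_tmp.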
- Parameters:
model (nn.Module) — The GNN model to explain, which tackles multiclass graph classification. Its forward function must have the form forward(self, graph, nfeat, embed, edge_weight). The output of its forward function is the logits if embed=False, and the intermediate node embeddings otherwise. (A minimal skeleton of this contract is sketched after this list.)
num_features (int) — Node embedding size used by model.
num_hops (int, optional) — The number of hops for GNN message passing, which must match the number of message-passing layers in the model to explain; required for node-level explanations. Default: None.
explain_graph (bool, optional) — Whether the explainer is used for graph-level (True) or node-level (False) predictions. Default: True.
coff_budget (float, optional) — Size regularization to constrain the explanation size. Default: 0.01.
coff_connect (float, optional) — Entropy regularization to constrain the connectivity of the explanation. Default: 5e-4.
sample_bias (float, optional) — Some members of a population are systematically more likely to be selected in a sample than others. Default: 0.0.
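For orientation, here is a minimal skeleton of the forward contract described above (MyHeteroModel, gnn_layers, and classify are hypothetical names used only for illustration; the complete runnable models appear in the examples below):

class MyHeteroModel(nn.Module):
    def forward(self, graph, nfeat, embed=False, edge_weight=None):
        # nfeat: dict mapping node type -> feature tensor
        # edge_weight: dict mapping edge type -> per-edge weight tensor, or None
        h = self.gnn_layers(graph, nfeat, edge_weight)  # hypothetical message-passing helper
        if embed:
            return h  # intermediate node embeddings, returned when embed=True
        return self.classify(graph, h)  # hypothetical readout returning logits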
- explain_graph(graph, feat, temperature=1.0, training=False, **kwargs)[source]
Learn and return an edge mask that plays a crucial role in explaining the prediction made by the GNN for the graph. Also, return the prediction made with the edges chosen based on the edge mask.
- Parameters:
graph (DGLGraph) — A heterogeneous graph.
feat (dict[str, Tensor]) — A dict mapping node types (keys) to feature tensors (values). The input features are of shape \((N_t, D_t)\), where \(N_t\) is the number of nodes of node type \(t\) and \(D_t\) is the feature size of node type \(t\).
temperature (float) — The temperature parameter fed to the sampling procedure.
training (bool) — Whether to train the explanation network.
kwargs (dict) — Additional arguments passed to the GNN model.
- Returns:
Tensor — Classification probabilities given the masked graph. It is a tensor of shape \((B, L)\), where \(L\) is the number of different label types in the dataset and \(B\) is the batch size.
dict[str, Tensor] — A dict mapping edge types (keys) to edge tensors (values) of shape \((E_t)\), where \(E_t\) is the number of edges of edge type \(t\) in the graph. A higher weight indicates a larger contribution of the edge.
Examples
>>> import dgl
>>> import torch as th
>>> import torch.nn as nn
>>> import numpy as np

>>> # Define the model
>>> class Model(nn.Module):
...     def __init__(self, in_feats, hid_feats, out_feats, rel_names):
...         super().__init__()
...         self.conv = dgl.nn.HeteroGraphConv(
...             {rel: dgl.nn.GraphConv(in_feats, hid_feats) for rel in rel_names},
...             aggregate="sum",
...         )
...         self.fc = nn.Linear(hid_feats, out_feats)
...         nn.init.xavier_uniform_(self.fc.weight)
...
...     def forward(self, g, h, embed=False, edge_weight=None):
...         if edge_weight:
...             mod_kwargs = {
...                 etype: {"edge_weight": mask} for etype, mask in edge_weight.items()
...             }
...             h = self.conv(g, h, mod_kwargs=mod_kwargs)
...         else:
...             h = self.conv(g, h)
...
...         if embed:
...             return h
...
...         with g.local_scope():
...             g.ndata["h"] = h
...             hg = 0
...             for ntype in g.ntypes:
...                 hg = hg + dgl.mean_nodes(g, "h", ntype=ntype)
...             return self.fc(hg)

>>> # Load dataset
>>> input_dim = 5
>>> hidden_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes["user"].data["h"] = th.randn(g.num_nodes("user"), input_dim)
>>> g.nodes["game"].data["h"] = th.randn(g.num_nodes("game"), input_dim)

>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)

>>> # define and train the model
>>> model = Model(input_dim, hidden_dim, num_classes, g.canonical_etypes)
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
...     logits = model(g, g.ndata["h"])
...     loss = th.nn.functional.cross_entropy(logits, th.tensor([1]))
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()

>>> # Initialize the explainer
>>> explainer = dgl.nn.HeteroPGExplainer(model, hidden_dim)

>>> # Train the explainer
>>> # Define explainer temperature parameter
>>> init_tmp, final_tmp = 5.0, 1.0
>>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)
>>> for epoch in range(20):
...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))
...     loss = explainer.train_step(g, g.ndata["h"], tmp)
...     optimizer_exp.zero_grad()
...     loss.backward()
...     optimizer_exp.step()

>>> # Explain the graph
>>> feat = g.ndata.pop("h")
>>> probs, edge_mask = explainer.explain_graph(g, feat)
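The returned edge_mask can then be thresholded to extract an explanatory subgraph. A post-processing sketch, assuming a cutoff of 0.5 (an illustrative choice, not part of the API):

>>> # Keep only the edges whose mask weight exceeds the cutoff
>>> kept = {
...     etype: (mask > 0.5).nonzero(as_tuple=True)[0]
...     for etype, mask in edge_mask.items()
... }
>>> explanation = dgl.edge_subgraph(g, kept)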
- explain_node(nodes, graph, feat, temperature=1.0, training=False, **kwargs)[source]
Learn and return an edge mask that plays a crucial role in explaining the prediction made by the GNN for the provided set of node IDs. Also, return the prediction made with the batched graph and edge mask.
- Parameters:
nodes (dict[str, Iterable[int]]) — A dict mapping node types (keys) to iterables of node IDs (values).
graph (DGLGraph) — A heterogeneous graph.
feat (dict[str, Tensor]) — A dict mapping node types (keys) to feature tensors (values). The input features are of shape \((N_t, D_t)\), where \(N_t\) is the number of nodes of node type \(t\) and \(D_t\) is the feature size of node type \(t\).
temperature (float) — The temperature parameter fed to the sampling procedure.
training (bool) — Whether to train the explanation network.
kwargs (dict) — Additional arguments passed to the GNN model.
- Returns:
dict[str, Tensor] — A dict mapping node types (keys) to classification probabilities of node labels (values). The values are tensors of shape \((N_t, L)\), where \(L\) is the number of different node label types in the dataset and \(N_t\) is the number of nodes of node type \(t\) in the graph.
dict[str, Tensor] — A dict mapping edge types (keys) to edge tensors (values) of shape \((E_t)\), where \(E_t\) is the number of edges of edge type \(t\) in the graph. A higher weight indicates a larger contribution of the edge.
DGLGraph — The batched set of subgraphs induced on the k-hop in-neighborhood of the input center nodes.
dict[str, Tensor] — A dict mapping node types (keys) to node ID tensors (values), which correspond to the center nodes in the subgraphs.
Examples
>>> import dgl
>>> import torch as th
>>> import torch.nn as nn
>>> import numpy as np

>>> # Define the model
>>> class Model(nn.Module):
...     def __init__(self, in_feats, hid_feats, out_feats, rel_names):
...         super().__init__()
...         self.conv = dgl.nn.HeteroGraphConv(
...             {rel: dgl.nn.GraphConv(in_feats, hid_feats) for rel in rel_names},
...             aggregate="sum",
...         )
...         self.fc = nn.Linear(hid_feats, out_feats)
...         nn.init.xavier_uniform_(self.fc.weight)
...
...     def forward(self, g, h, embed=False, edge_weight=None):
...         if edge_weight:
...             mod_kwargs = {
...                 etype: {"edge_weight": mask} for etype, mask in edge_weight.items()
...             }
...             h = self.conv(g, h, mod_kwargs=mod_kwargs)
...         else:
...             h = self.conv(g, h)
...
...         return h

>>> # Load dataset
>>> input_dim = 5
>>> hidden_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes["user"].data["h"] = th.randn(g.num_nodes("user"), input_dim)
>>> g.nodes["game"].data["h"] = th.randn(g.num_nodes("game"), input_dim)

>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)

>>> # define and train the model
>>> model = Model(input_dim, hidden_dim, num_classes, g.canonical_etypes)
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
...     logits = model(g, g.ndata["h"])["user"]
...     loss = th.nn.functional.cross_entropy(logits, th.tensor([1, 1, 1]))
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()

>>> # Initialize the explainer
>>> explainer = dgl.nn.HeteroPGExplainer(
...     model, hidden_dim, num_hops=2, explain_graph=False
... )

>>> # Train the explainer
>>> # Define explainer temperature parameter
>>> init_tmp, final_tmp = 5.0, 1.0
>>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)
>>> for epoch in range(20):
...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))
...     loss = explainer.train_step_node(
...         {ntype: g.nodes(ntype) for ntype in g.ntypes},
...         g, g.ndata["h"], tmp
...     )
...     optimizer_exp.zero_grad()
...     loss.backward()
...     optimizer_exp.step()

>>> # Explain the node
>>> feat = g.ndata.pop("h")
>>> probs, edge_mask, bg, inverse_indices = explainer.explain_node(
...     {"user": [0]}, g, feat
... )
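Since the returned probabilities are defined over the batched subgraphs, the prediction for the explained center node can be read back through inverse_indices. A usage sketch under that assumption:

>>> # Classification probabilities of the center node, user 0
>>> center_probs = probs["user"][inverse_indices["user"]]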
- forward(*input: Any) → None
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.