HeteroGNNExplainer

class dgl.nn.pytorch.explain.HeteroGNNExplainer(model, num_hops, lr=0.01, num_epochs=100, *, alpha1=0.005, alpha2=1.0, beta1=1.0, beta2=0.1, log=True)[source]

基类:Module

来自 GNNExplainer: Generating Explanations for Graph Neural Networks 的 GNNExplainer 模型,针对异构图进行了调整

它识别出在基于 GNN 的节点分类和图分类中起关键作用的紧凑子图结构和小的节点特征子集。

为了生成解释,它通过优化以下目标函数来学习边缘掩码 \(M\) 和特征掩码 \(F\)。

\[l(y, \hat{y}) + \alpha_1 \|M\|_1 + \alpha_2 H(M) + \beta_1 \|F\|_1 + \beta_2 H(F)\]

其中 \(l\) 是损失函数,\(y\) 是原始模型预测,\(\hat{y}\) 是应用边缘和特征掩码后的模型预测,\(H\) 是熵函数。

参数:
  • model (nn.Module) –

    要解释的 GNN 模型。

    • 其 forward 函数所需的参数是 graph 和 feat。后者用于输入节点特征。

    • 它还应可选地接受一个用于边缘权重的 eweight 参数,并在消息传递中将消息乘以它。

    • 其 forward 函数的输出是预测的节点/图类别的 logits。

    另请参阅 explain_node()explain_graph() 中的示例。

  • num_hops (int) – GNN 信息聚合的跳数。

  • lr (float, 可选) – 要使用的学习率,默认为 0.01。

  • num_epochs (int, 可选) – 训练的 epoch 数。

  • alpha1 (float, 可选) – 值越高,通过减少边缘掩码的总和,使解释边缘掩码越稀疏。

  • alpha2 (float, 可选) – 值越高,通过减少边缘掩码的熵,使解释边缘掩码越稀疏。

  • beta1 (float, 可选) – 值越高,通过减少节点特征掩码的均值,使解释节点特征掩码越稀疏。

  • beta2 (float, 可选) – 值越高,通过减少节点特征掩码的熵,使解释节点特征掩码越稀疏。

  • log (bool, 可选) – 如果为 True,将记录计算过程,默认为 True。

explain_graph(graph, feat, **kwargs)[source]

学习并返回在解释 GNN 对图的预测中起关键作用的节点特征掩码和边缘掩码。

参数:
  • graph (DGLGraph) – 将要解释的异构图。

  • feat (dict[str, Tensor]) – 将输入节点特征(值)与图中存在的相应节点类型(键)关联起来的字典。输入特征的形状为 \( (N_t, D_t) \)。\( N_t \) 是节点类型 \( t \) 的节点数,\( D_t \) 是节点类型 \( t \) 的特征大小。

  • kwargs (dict) – 传递给 GNN 模型的附加参数。

返回值:

  • feat_mask (dict[str, Tensor]) – 将学习到的节点特征重要性掩码(值)与相应的节点类型(键)关联起来的字典。掩码的形状为 \( (D_t) \),其中 \( D_t \) 是节点类型 t 的节点特征大小。值在 \( (0, 1) \) 范围内。值越高,表示越重要。

  • edge_mask (dict[Tuple[str], Tensor]) – 将学习到的边缘重要性掩码(值)与相应的规范边缘类型(键)关联起来的字典。掩码的形状为 \( (E_t) \),其中 \( E_t \) 是图中规范边缘类型 \( t \) 的边缘数量。值在 \( (0, 1) \) 范围内。值越高,表示越重要。

示例

>>> import dgl
>>> import dgl.function as fn
>>> import torch as th
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>> from dgl.nn import HeteroGNNExplainer
>>> class Model(nn.Module):
...     def __init__(self, in_dim, num_classes, canonical_etypes):
...         super(Model, self).__init__()
...         self.etype_weights = nn.ModuleDict({
...             '_'.join(c_etype): nn.Linear(in_dim, num_classes)
...             for c_etype in canonical_etypes
...         })
...
...     def forward(self, graph, feat, eweight=None):
...         with graph.local_scope():
...             c_etype_func_dict = {}
...             for c_etype in graph.canonical_etypes:
...                 src_type, etype, dst_type = c_etype
...                 wh = self.etype_weights['_'.join(c_etype)](feat[src_type])
...                 graph.nodes[src_type].data[f'h_{c_etype}'] = wh
...                 if eweight is None:
...                     c_etype_func_dict[c_etype] = (fn.copy_u(f'h_{c_etype}', 'm'),
...                         fn.mean('m', 'h'))
...                 else:
...                     graph.edges[c_etype].data['w'] = eweight[c_etype]
...                     c_etype_func_dict[c_etype] = (
...                         fn.u_mul_e(f'h_{c_etype}', 'w', 'm'), fn.mean('m', 'h'))
...             graph.multi_update_all(c_etype_func_dict, 'sum')
...             hg = 0
...             for ntype in graph.ntypes:
...                 if graph.num_nodes(ntype):
...                     hg = hg + dgl.mean_nodes(graph, 'h', ntype=ntype)
...             return hg
>>> input_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({
...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes['user'].data['h'] = th.randn(g.num_nodes('user'), input_dim)
>>> g.nodes['game'].data['h'] = th.randn(g.num_nodes('game'), input_dim)
>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)
>>> # define and train the model
>>> model = Model(input_dim, num_classes, g.canonical_etypes)
>>> feat = g.ndata['h']
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
...     logits = model(g, feat)
...     loss = F.cross_entropy(logits, th.tensor([1]))
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Explain for the graph
>>> explainer = HeteroGNNExplainer(model, num_hops=1)
>>> feat_mask, edge_mask = explainer.explain_graph(g, feat)
>>> feat_mask
{'game': tensor([0.2684, 0.2597, 0.3135, 0.2976, 0.2607]),
 'user': tensor([0.2216, 0.2908, 0.2644, 0.2738, 0.2663])}
>>> edge_mask
{('game', 'rev_plays', 'user'): tensor([0.8922, 0.1966, 0.8371, 0.1330]),
 ('user', 'plays', 'game'): tensor([0.1785, 0.1696, 0.8065, 0.2167])}
explain_node(ntype, node_id, graph, feat, **kwargs)[source]

学习并返回节点特征掩码和一个子图,这些掩码和子图在解释 GNN 对类型为 ntype、ID 为 node_id 的节点的预测中起关键作用。

它要求 model 返回一个字典,将节点类型映射到特定于类型的预测。

参数:
  • ntype (str) – 要解释的节点类型。model 必须经过训练才能对此特定节点类型进行预测。

  • node_id (int) – 要解释的节点 ID。

  • graph (DGLGraph) – 一个异构图。

  • feat (dict[str, Tensor]) – 将输入节点特征(值)与图中存在的相应节点类型(键)关联起来的字典。输入特征的形状为 \( (N_t, D_t) \)。\( N_t \) 是节点类型 \( t \) 的节点数,\( D_t \) 是节点类型 \( t \) 的特征大小。

  • kwargs (dict) – 传递给 GNN 模型的附加参数。

返回值:

  • new_node_id (Tensor) – 输入中心节点的新 ID。

  • sg (DGLGraph) – 在输入中心节点的 k 跳入邻域上诱导的子图。

  • feat_mask (dict[str, Tensor]) – 将学习到的节点特征重要性掩码(值)与相应的节点类型(键)关联起来的字典。掩码的形状为 \( (D_t) \),其中 \( D_t \) 是节点类型 t 的节点特征大小。值在 \( (0, 1) \) 范围内。值越高,表示越重要。

  • edge_mask (dict[Tuple[str], Tensor]) – 将学习到的边缘重要性掩码(值)与相应的规范边缘类型(键)关联起来的字典。掩码的形状为 \( (E_t) \),其中 \( E_t \) 是子图中规范边缘类型 \( t \) 的边缘数量。值在 \( (0, 1) \) 范围内。值越高,表示越重要。

示例

>>> import dgl
>>> import dgl.function as fn
>>> import torch as th
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>> from dgl.nn import HeteroGNNExplainer
>>> class Model(nn.Module):
...     def __init__(self, in_dim, num_classes, canonical_etypes):
...         super(Model, self).__init__()
...         self.etype_weights = nn.ModuleDict({
...             '_'.join(c_etype): nn.Linear(in_dim, num_classes)
...             for c_etype in canonical_etypes
...         })
...
...     def forward(self, graph, feat, eweight=None):
...         with graph.local_scope():
...             c_etype_func_dict = {}
...             for c_etype in graph.canonical_etypes:
...                 src_type, etype, dst_type = c_etype
...                 wh = self.etype_weights['_'.join(c_etype)](feat[src_type])
...                 graph.nodes[src_type].data[f'h_{c_etype}'] = wh
...                 if eweight is None:
...                     c_etype_func_dict[c_etype] = (fn.copy_u(f'h_{c_etype}', 'm'),
...                         fn.mean('m', 'h'))
...                 else:
...                     graph.edges[c_etype].data['w'] = eweight[c_etype]
...                     c_etype_func_dict[c_etype] = (
...                         fn.u_mul_e(f'h_{c_etype}', 'w', 'm'), fn.mean('m', 'h'))
...             graph.multi_update_all(c_etype_func_dict, 'sum')
...             return graph.ndata['h']
>>> input_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({
...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes['user'].data['h'] = th.randn(g.num_nodes('user'), input_dim)
>>> g.nodes['game'].data['h'] = th.randn(g.num_nodes('game'), input_dim)
>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)
>>> # define and train the model
>>> model = Model(input_dim, num_classes, g.canonical_etypes)
>>> feat = g.ndata['h']
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
...     logits = model(g, feat)['user']
...     loss = F.cross_entropy(logits, th.tensor([1, 1, 1]))
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Explain the prediction for node 0 of type 'user'
>>> explainer = HeteroGNNExplainer(model, num_hops=1)
>>> new_center, sg, feat_mask, edge_mask = explainer.explain_node('user', 0, g, feat)
>>> new_center
tensor([0])
>>> sg
Graph(num_nodes={'game': 1, 'user': 1},
      num_edges={('game', 'rev_plays', 'user'): 1, ('user', 'plays', 'game'): 1,
                 ('user', 'rev_rev_plays', 'game'): 1},
      metagraph=[('game', 'user', 'rev_plays'), ('user', 'game', 'plays'),
                 ('user', 'game', 'rev_rev_plays')])
>>> feat_mask
{'game': tensor([0.2348, 0.2780, 0.2611, 0.2513, 0.2823]),
 'user': tensor([0.2716, 0.2450, 0.2658, 0.2876, 0.2738])}
>>> edge_mask
{('game', 'rev_plays', 'user'): tensor([0.0630]),
 ('user', 'plays', 'game'): tensor([0.1939]),
 ('user', 'rev_rev_plays', 'game'): tensor([0.9166])}
forward(*input: Any) None

定义每次调用时执行的计算。

应由所有子类覆盖。

注意

虽然前向传播的实现需要在此函数中定义,但之后应调用 Module 实例而不是此函数,因为前者负责运行注册的钩子,而后者会静默忽略它们。