HeteroSubgraphX

class dgl.nn.pytorch.explain.HeteroSubgraphX(model, num_hops, coef=10.0, high2low=True, num_child=12, num_rollouts=20, node_min=3, shapley_steps=100, log=False)[source]

基础类: Module

SubgraphX,摘自论文 On Explainability of Graph Neural Networks via Subgraph Explorations,适用于异构图

它从原始图中识别出最重要的子图,该子图在基于 GNN 的图分类中起着关键作用。

它采用蒙特卡洛树搜索 (MCTS) 有效探索不同的子图以进行解释,并使用 Shapley 值作为衡量子图重要性的指标。

参数:
  • model (nn.Module) –

    用于解释的多类别图分类 GNN 模型

    • 其 forward 函数必须具有 forward(self, graph, nfeat) 的形式。

    • 其 forward 函数的输出是 logits。

  • num_hops (int) – 模型中的消息传递层数

  • coef (float, optional) – 此超参数控制探索与利用之间的权衡。值越高,算法越倾向于探索相对未访问的节点。默认值: 10.0

  • high2low (bool, optional) – 如果为 True,则在扩展搜索树中的子节点时,将使用“High2low”策略来修剪动作,即从高度节点扩展到低度节点。否则,将使用“Low2high”策略。默认值: True

  • num_child (int, optional) – 这是在扩展搜索树中的子节点时要展开的子节点数量。默认值: 12

  • num_rollouts (int, optional) – 这是 MCTS 的 rollout 次数。默认值: 20

  • node_min (int, optional) – 这是根据子图中的节点数量定义叶节点的阈值。默认值: 3

  • shapley_steps (int, optional) – 这是估算 Shapley 值时进行蒙特卡洛采样的步数。默认值: 100

  • log (bool, optional) – 如果为 True,则会记录进度。默认值: False

explain_graph(graph, feat, target_class, **kwargs)[source]

找到原始图中最重要的子图,用于模型将图分类到目标类别。

参数:
  • graph (DGLGraph) – 异构图

  • feat (dict[str, Tensor]) – 将输入节点特征(值)与图中存在的相应节点类型(键)关联的字典。输入特征的形状为 \((N_t, D_t)\)\(N_t\) 是节点类型 \(t\) 的节点数,\(D_t\) 是节点类型 \(t\) 的特征大小。

  • target_class (int) – 要解释的目标类别

  • kwargs (dict) – 传递给 GNN 模型的附加参数

返回值:

将张量节点 ID(值)与节点类型(键)关联的字典,表示最重要的子图

返回类型:

dict[str, Tensor]

示例

>>> import dgl
>>> import dgl.function as fn
>>> import torch as th
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>> from dgl.nn import HeteroSubgraphX
>>> class Model(nn.Module):
...     def __init__(self, in_dim, num_classes, canonical_etypes):
...         super(Model, self).__init__()
...         self.etype_weights = nn.ModuleDict(
...             {
...                 "_".join(c_etype): nn.Linear(in_dim, num_classes)
...                 for c_etype in canonical_etypes
...             }
...         )
...
...     def forward(self, graph, feat):
...         with graph.local_scope():
...             c_etype_func_dict = {}
...             for c_etype in graph.canonical_etypes:
...                 src_type, etype, dst_type = c_etype
...                 wh = self.etype_weights["_".join(c_etype)](feat[src_type])
...                 graph.nodes[src_type].data[f"h_{c_etype}"] = wh
...                 c_etype_func_dict[c_etype] = (
...                     fn.copy_u(f"h_{c_etype}", "m"),
...                     fn.mean("m", "h"),
...                 )
...             graph.multi_update_all(c_etype_func_dict, "sum")
...             hg = 0
...             for ntype in graph.ntypes:
...                 if graph.num_nodes(ntype):
...                     hg = hg + dgl.mean_nodes(graph, "h", ntype=ntype)
...             return hg
>>> input_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes["user"].data["h"] = th.randn(g.num_nodes("user"), input_dim)
>>> g.nodes["game"].data["h"] = th.randn(g.num_nodes("game"), input_dim)
>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)
>>> # define and train the model
>>> model = Model(input_dim, num_classes, g.canonical_etypes)
>>> feat = g.ndata["h"]
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
...     logits = model(g, feat)
...     loss = F.cross_entropy(logits, th.tensor([1]))
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Explain for the graph
>>> explainer = HeteroSubgraphX(model, num_hops=1)
>>> explainer.explain_graph(g, feat, target_class=1)
{'game': tensor([0, 1]), 'user': tensor([1, 2])}
forward(*input: Any) None

定义每次调用时执行的计算。

应被所有子类重写。

注意

尽管前向传播的实现需要在此函数中定义,但之后应该调用 Module 实例而不是此函数本身,因为前者会处理注册的钩子,而后者会默默忽略它们。