HeteroSubgraphX
- class dgl.nn.pytorch.explain.HeteroSubgraphX(model, num_hops, coef=10.0, high2low=True, num_child=12, num_rollouts=20, node_min=3, shapley_steps=100, log=False)
Bases: Module
SubgraphX from the paper On Explainability of Graph Neural Networks via Subgraph Explorations, adapted for heterogeneous graphs.
It identifies the most important subgraph from the original graph that plays a critical role in GNN-based graph classification.
It employs Monte Carlo tree search (MCTS) to efficiently explore different subgraphs for explanation and uses Shapley values as the measure of subgraph importance.
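Conceptually, the Shapley value of a candidate subgraph is estimated by Monte Carlo sampling: for random coalitions of neighboring nodes, average the marginal gain in the target-class score from adding the subgraph's nodes. The sketch below is purely illustrative, not the DGL implementation; all names are hypothetical, and score_fn stands for any function that scores a retained node set with the model:

>>> import torch as th
>>> def mc_shapley(score_fn, subgraph_nodes, neighbor_nodes, steps=100):
...     total = 0.0
...     for _ in range(steps):
...         # Sample a random coalition of neighboring nodes.
...         keep = th.rand(len(neighbor_nodes)) > 0.5
...         coalition = [n for n, k in zip(neighbor_nodes, keep) if k]
...         # Marginal contribution of the subgraph to this coalition.
...         total += score_fn(subgraph_nodes + coalition) - score_fn(coalition)
...     return total / steps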
- Parameters:
model (nn.Module) – The GNN model to explain that tackles multiclass graph classification. Its forward function must have the form forward(self, graph, nfeat). The output of its forward function is the logits.
num_hops (int) – The number of message-passing layers in the model
coef (float, optional) – This hyperparameter controls the trade-off between exploration and exploitation. A higher value encourages the algorithm to explore relatively unvisited nodes. Default: 10.0
high2low (bool, optional) – If True, it will use the "High2low" strategy to prune actions when expanding children nodes in the search tree, i.e., expanding from high-degree nodes to low-degree nodes. Otherwise, it will use the "Low2high" strategy. Default: True
num_child (int, optional) – This is the number of children nodes to expand when extending children nodes in the search tree. Default: 12
num_rollouts (int, optional) – This is the number of rollouts for MCTS. Default: 20
node_min (int, optional) – This is the threshold for defining a leaf node based on the number of nodes in a subgraph. Default: 3
shapley_steps (int, optional) – This is the number of Monte Carlo sampling steps for estimating Shapley values. Default: 100
log (bool, optional) – If True, it will log the progress. Default: False
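For example, an explainer with non-default search settings could be constructed as follows (a sketch; model is assumed to be a trained GNN whose forward function follows the required form):

>>> explainer = HeteroSubgraphX(
...     model, num_hops=2, coef=5.0, high2low=False,
...     num_child=8, num_rollouts=10, node_min=4,
...     shapley_steps=50, log=True,
... )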
- explain_graph(graph, feat, target_class, **kwargs)
Find the most important subgraph from the original graph for the model to classify the graph into the target class.
- Parameters:
graph (DGLGraph) – A heterogeneous graph
feat (dict[str, Tensor]) – The dictionary that associates input node features (values) with the respective node types (keys)
target_class (int) – The target class to explain
kwargs (dict) – Additional arguments passed to the GNN model
- Returns:
A dictionary associating node types (keys) with tensor node IDs (values), representing the most important subgraph
- Return type:
dict[str, Tensor]
Examples
>>> import dgl
>>> import dgl.function as fn
>>> import torch as th
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>> from dgl.nn import HeteroSubgraphX
>>> class Model(nn.Module):
...     def __init__(self, in_dim, num_classes, canonical_etypes):
...         super(Model, self).__init__()
...         self.etype_weights = nn.ModuleDict(
...             {
...                 "_".join(c_etype): nn.Linear(in_dim, num_classes)
...                 for c_etype in canonical_etypes
...             }
...         )
...
...     def forward(self, graph, feat):
...         with graph.local_scope():
...             c_etype_func_dict = {}
...             for c_etype in graph.canonical_etypes:
...                 src_type, etype, dst_type = c_etype
...                 wh = self.etype_weights["_".join(c_etype)](feat[src_type])
...                 graph.nodes[src_type].data[f"h_{c_etype}"] = wh
...                 c_etype_func_dict[c_etype] = (
...                     fn.copy_u(f"h_{c_etype}", "m"),
...                     fn.mean("m", "h"),
...                 )
...             graph.multi_update_all(c_etype_func_dict, "sum")
...             hg = 0
...             for ntype in graph.ntypes:
...                 if graph.num_nodes(ntype):
...                     hg = hg + dgl.mean_nodes(graph, "h", ntype=ntype)
...             return hg
>>> input_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes["user"].data["h"] = th.randn(g.num_nodes("user"), input_dim)
>>> g.nodes["game"].data["h"] = th.randn(g.num_nodes("game"), input_dim)
>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)
>>> # define and train the model
>>> model = Model(input_dim, num_classes, g.canonical_etypes)
>>> feat = g.ndata["h"]
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
...     logits = model(g, feat)
...     loss = F.cross_entropy(logits, th.tensor([1]))
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Explain for the graph
>>> explainer = HeteroSubgraphX(model, num_hops=1)
>>> explainer.explain_graph(g, feat, target_class=1)
{'game': tensor([0, 1]), 'user': tensor([1, 2])}
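Since the returned dictionary maps node types to node IDs, it can be passed directly to dgl.node_subgraph to materialize the explanation, for example (continuing the example above):

>>> explanation = explainer.explain_graph(g, feat, target_class=1)
>>> sg = dgl.node_subgraph(g, explanation)  # induced subgraph on the explained nodes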