Transformer 作为图神经网络

作者:Zihao Ye, Jinjing Zhou, Qipeng Guo, Quan Gan, Zheng Zhang

警告

本教程旨在通过代码解释论文,以帮助理解。因此,此实现并未针对运行效率进行优化。对于推荐的实现,请参考官方示例

在本教程中,你将学习 Transformer 模型的简化实现。你可以看到一些最重要的设计要点。例如,这里只有单头注意力。完整的代码可以在这里找到。

整体结构类似于研究论文《Transformer 带注释版》中的结构。

Transformer 模型作为 CNN/RNN 架构在序列建模方面的替代品,在研究论文《Attention is All You Need》中被提出。它改进了机器翻译以及自然语言推理任务(GPT)的最新技术水平。最近关于使用大规模语料库预训练 Transformer(BERT)的研究支持了它能够学习高质量语义表示的观点。

Transformer 的有趣之处在于其广泛使用注意力机制。注意力的经典用法来自于机器翻译模型,其中输出 token 会关注所有输入 token。

Transformer 还在解码器和编码器中额外应用了自注意力。这个过程迫使相互关联的词汇结合在一起,而不管它们在序列中的位置。这与基于 RNN 的模型不同,后者将词汇(在源句子中)沿着链条结合,这被认为限制性太强。

Transformer 的注意力层

在 Transformer 的注意力层中,对于每个节点,模块学习为其输入边分配权重。对于节点对\((i, j)\)(从\(i\)\(j\)),节点特征为\(x_i, x_j \in \mathbb{R}^n\),它们连接的得分定义如下:

\[\begin{split}q_j = W_q\cdot x_j \\ k_i = W_k\cdot x_i\\ v_i = W_v\cdot x_i\\ \textrm{score} = q_j^T k_i\end{split}\]

其中\(W_q, W_k, W_v \in \mathbb{R}^{n\times d_k}\)分别将表示\(x\)映射到“查询 (query)”、“键 (key)”和“值 (value)”空间。

还有其他实现得分函数的方法。点积衡量给定查询\(q_j\)和键\(k_i\)的相似度:如果\(j\)需要存储在\(i\)中的信息,则位置\(j\)的查询向量(\(q_j\))应该接近位置\(i\)的键向量(\(k_i\))。

然后使用得分计算输入值的总和,并根据边的权重进行归一化,结果存储在\(\textrm{wv}\)中。然后对\(\textrm{wv}\)应用一个仿射层以获得输出\(o\)

\[\begin{split}w_{ji} = \frac{\exp\{\textrm{score}_{ji} \}}{\sum\limits_{(k, i)\in E}\exp\{\textrm{score}_{ki} \}} \\ \textrm{wv}_i = \sum_{(k, i)\in E} w_{ki} v_k \\ o = W_o\cdot \textrm{wv} \\\end{split}\]

多头注意力层

在 Transformer 中,注意力是多头的。一个头很像卷积网络中的一个通道。多头注意力由多个注意力头组成,其中每个头指代一个单独的注意力模块。所有头的\(\textrm{wv}^{(i)}\)被拼接起来,并通过一个仿射层映射到输出\(o\)

\[o = W_o \cdot \textrm{concat}\left(\left[\textrm{wv}^{(0)}, \textrm{wv}^{(1)}, \cdots, \textrm{wv}^{(h)}\right]\right)\]

下面的代码封装了多头注意力所需的组件,并提供了两个接口。

  • get 将状态“x”映射到查询 (query)、键 (key) 和值 (value),这是后续步骤 (propagate_attention) 所需的。

  • get_o 将注意力后的更新值映射到输出\(o\),用于后处理。

class MultiHeadAttention(nn.Module):
    "Multi-Head Attention"
    def __init__(self, h, dim_model):
        "h: number of heads; dim_model: hidden dimension"
        super(MultiHeadAttention, self).__init__()
        self.d_k = dim_model // h
        self.h = h
        # W_q, W_k, W_v, W_o
        self.linears = clones(nn.Linear(dim_model, dim_model), 4)

    def get(self, x, fields='qkv'):
        "Return a dict of queries / keys / values."
        batch_size = x.shape[0]
        ret = {}
        if 'q' in fields:
            ret['q'] = self.linears[0](x).view(batch_size, self.h, self.d_k)
        if 'k' in fields:
            ret['k'] = self.linears[1](x).view(batch_size, self.h, self.d_k)
        if 'v' in fields:
            ret['v'] = self.linears[2](x).view(batch_size, self.h, self.d_k)
        return ret

    def get_o(self, x):
        "get output of the multi-head attention"
        batch_size = x.shape[0]
        return self.linears[3](x.view(batch_size, -1))

DGL 如何使用图神经网络实现 Transformer

通过将注意力视为图中的边并在边上采用消息传递来引发适当的处理,你可以获得 Transformer 的不同视角。

图结构

通过将源句子和目标句子的 token 映射到节点来构建图。完整的 Transformer 图由三个子图组成

源语言图。这是一个完全图,每个 token \(s_i\) 可以关注任何其他 token \(s_j\)(包括自环)。 image0 目标语言图。该图是半完全的,即\(t_i\)仅在\(i > j\)时关注\(t_j\)(输出 token 不能依赖未来的词汇)。 image1 跨语言图。这是一个二分图,其中从每个源 token \(s_i\) 到每个目标 token \(t_j\) 都有一条边,这意味着每个目标 token 都可以关注源 token。 image2

整体图如下所示: image3

在数据集准备阶段预先构建图。

消息传递

一旦定义了图结构,就可以继续定义消息传递的计算。

假设你已经计算了所有的查询\(q_i\)、键\(k_i\)和值\(v_i\)。对于每个节点\(i\)(无论是源 token 还是目标 token),你可以将注意力计算分解为两个步骤

  1. 消息计算: 通过计算\(q_i\)\(k_j\)的缩放点积,计算\(i\)与所有待关注节点\(j\)之间的注意力得分\(\mathrm{score}_{ij}\)。从\(j\)发送到\(i\)的消息将包含得分\(\mathrm{score}_{ij}\)和值\(v_j\)

  2. 消息聚合: 根据得分\(\mathrm{score}_{ij}\),聚合来自所有\(j\)的值\(v_j\)

简单实现

消息计算

计算score 并将源节点的v发送到目标节点的邮箱

def message_func(edges):
    return {'score': ((edges.src['k'] * edges.dst['q'])
                      .sum(-1, keepdim=True)),
            'v': edges.src['v']}
消息聚合

对所有输入边进行归一化并加权求和以获得输出

import torch as th
import torch.nn.functional as F

def reduce_func(nodes, d_k=64):
    v = nodes.mailbox['v']
    att = F.softmax(nodes.mailbox['score'] / th.sqrt(d_k), 1)
    return {'dx': (att * v).sum(1)}
在特定边上执行
import functools.partial as partial
def naive_propagate_attention(self, g, eids):
    g.send_and_recv(eids, message_func, partial(reduce_func, d_k=self.d_k))

使用内置函数加速

为了加速消息传递过程,请使用 DGL 的内置函数,包括

  • fn.src_mul_egdes(src_field, edges_field, out_field) 将源节点属性与边属性相乘,并将结果发送到目标节点的邮箱,键为out_field

  • fn.copy_e(edges_field, out_field) 将边属性复制到目标节点的邮箱。

  • fn.sum(edges_field, out_field) 对边属性求和,并将聚合结果发送到目标节点的邮箱。

在这里,你将这些内置函数组装到propagate_attention中,这也是最终实现中的主要图操作函数。为了加速它,将softmax操作分解为以下步骤。回想一下,每个头有两个阶段。

  1. 通过将源节点的k与目标节点的q相乘来计算注意力得分

    • g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)

  2. 对所有目标节点的输入边进行缩放 Softmax

    • 步骤 1:使用缩放归一化常数对得分进行指数化

      • g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)))

        \[\textrm{score}_{ij}\leftarrow\exp{\left(\frac{\textrm{score}_{ij}}{ \sqrt{d_k}}\right)}\]
    • 步骤 2:获取关联节点上按每个节点的输入边上的“得分”加权后的“值”;获取每个节点的输入边上的“得分”之和用于归一化。注意,这里的\(\textrm{wv}\)未进行归一化。

      • msg: fn.u_mul_e('v', 'score', 'v'), reduce: fn.sum('v', 'wv')

        \[\textrm{wv}_j=\sum_{i=1}^{N} \textrm{score}_{ij} \cdot v_i\]
      • msg: fn.copy_e('score', 'score'), reduce: fn.sum('score', 'z')

        \[\textrm{z}_j=\sum_{i=1}^{N} \textrm{score}_{ij}\]

\(\textrm{wv}\)的归一化留待后处理。

def src_dot_dst(src_field, dst_field, out_field):
    def func(edges):
        return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)}

    return func

def scaled_exp(field, scale_constant):
    def func(edges):
        # clamp for softmax numerical stability
        return {field: th.exp((edges.data[field] / scale_constant).clamp(-5, 5))}

    return func


def propagate_attention(self, g, eids):
    # Compute attention score
    g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
    g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)))
    # Update node state
    g.send_and_recv(eids,
                    [fn.u_mul_e('v', 'score', 'v'), fn.copy_e('score', 'score')],
                    [fn.sum('v', 'wv'), fn.sum('score', 'z')])

预处理和后处理

在 Transformer 中,数据需要在调用propagate_attention函数之前和之后进行预处理和后处理。

预处理 预处理函数pre_func首先对节点表示进行归一化,然后将它们映射到一组查询、键和值,以自注意力为例

\[\begin{split}x \leftarrow \textrm{LayerNorm}(x) \\ [q, k, v] \leftarrow [W_q, W_k, W_v ]\cdot x\end{split}\]

后处理 后处理函数post_funcs完成 Transformer 一层的全部计算:1. 对\(\textrm{wv}\)进行归一化,并获得多头注意力层的输出\(o\)

\[\begin{split}\textrm{wv} \leftarrow \frac{\textrm{wv}}{z} \\ o \leftarrow W_o\cdot \textrm{wv} + b_o\end{split}\]

添加残差连接

\[x \leftarrow x + o\]
  1. \(x\)应用一个两层位置感知前馈层,然后添加残差连接

    \[x \leftarrow x + \textrm{LayerNorm}(\textrm{FFN}(x))\]

    其中\(\textrm{FFN}\)指前馈函数。

class Encoder(nn.Module):
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.N = N
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def pre_func(self, i, fields='qkv'):
        layer = self.layers[i]
        def func(nodes):
            x = nodes.data['x']
            norm_x = layer.sublayer[0].norm(x)
            return layer.self_attn.get(norm_x, fields=fields)
        return func

    def post_func(self, i):
        layer = self.layers[i]
        def func(nodes):
            x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']
            o = layer.self_attn.get_o(wv / z)
            x = x + layer.sublayer[0].dropout(o)
            x = layer.sublayer[1](x, layer.feed_forward)
            return {'x': x if i < self.N - 1 else self.norm(x)}
        return func

class Decoder(nn.Module):
    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.N = N
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def pre_func(self, i, fields='qkv', l=0):
        layer = self.layers[i]
        def func(nodes):
            x = nodes.data['x']
            if fields == 'kv':
                norm_x = x # In enc-dec attention, x has already been normalized.
            else:
                norm_x = layer.sublayer[l].norm(x)
            return layer.self_attn.get(norm_x, fields)
        return func

    def post_func(self, i, l=0):
        layer = self.layers[i]
        def func(nodes):
            x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']
            o = layer.self_attn.get_o(wv / z)
            x = x + layer.sublayer[l].dropout(o)
            if l == 1:
                x = layer.sublayer[2](x, layer.feed_forward)
            return {'x': x if i < self.N - 1 else self.norm(x)}
        return func

这就完成了 Transformer 编码器和解码器一层的所有过程。

注意

子层连接部分与原论文略有不同。然而,此实现与《Transformer 带注释版》OpenNMT中的实现相同。

Transformer 图的主类

Transformer 的处理流程可以看作是在完整图中的两阶段消息传递(适当添加预处理和后处理):1)编码器中的自注意力,2)解码器中的自注意力,随后是编码器和解码器之间的交叉注意力,如下所示。 image4

class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, pos_enc, generator, h, d_k):
        super(Transformer, self).__init__()
        self.encoder, self.decoder = encoder, decoder
        self.src_embed, self.tgt_embed = src_embed, tgt_embed
        self.pos_enc = pos_enc
        self.generator = generator
        self.h, self.d_k = h, d_k

    def propagate_attention(self, g, eids):
        # Compute attention score
        g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
        g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)))
        # Send weighted values to target nodes
        g.send_and_recv(eids,
                        [fn.u_mul_e('v', 'score', 'v'), fn.copy_e('score', 'score')],
                        [fn.sum('v', 'wv'), fn.sum('score', 'z')])

    def update_graph(self, g, eids, pre_pairs, post_pairs):
        "Update the node states and edge states of the graph."

        # Pre-compute queries and key-value pairs.
        for pre_func, nids in pre_pairs:
            g.apply_nodes(pre_func, nids)
        self.propagate_attention(g, eids)
        # Further calculation after attention mechanism
        for post_func, nids in post_pairs:
            g.apply_nodes(post_func, nids)

    def forward(self, graph):
        g = graph.g
        nids, eids = graph.nids, graph.eids

        # Word Embedding and Position Embedding
        src_embed, src_pos = self.src_embed(graph.src[0]), self.pos_enc(graph.src[1])
        tgt_embed, tgt_pos = self.tgt_embed(graph.tgt[0]), self.pos_enc(graph.tgt[1])
        g.nodes[nids['enc']].data['x'] = self.pos_enc.dropout(src_embed + src_pos)
        g.nodes[nids['dec']].data['x'] = self.pos_enc.dropout(tgt_embed + tgt_pos)

        for i in range(self.encoder.N):
            # Step 1: Encoder Self-attention
            pre_func = self.encoder.pre_func(i, 'qkv')
            post_func = self.encoder.post_func(i)
            nodes, edges = nids['enc'], eids['ee']
            self.update_graph(g, edges, [(pre_func, nodes)], [(post_func, nodes)])

        for i in range(self.decoder.N):
            # Step 2: Dncoder Self-attention
            pre_func = self.decoder.pre_func(i, 'qkv')
            post_func = self.decoder.post_func(i)
            nodes, edges = nids['dec'], eids['dd']
            self.update_graph(g, edges, [(pre_func, nodes)], [(post_func, nodes)])
            # Step 3: Encoder-Decoder attention
            pre_q = self.decoder.pre_func(i, 'q', 1)
            pre_kv = self.decoder.pre_func(i, 'kv', 1)
            post_func = self.decoder.post_func(i, 1)
            nodes_e, nodes_d, edges = nids['enc'], nids['dec'], eids['ed']
            self.update_graph(g, edges, [(pre_q, nodes_d), (pre_kv, nodes_e)], [(post_func, nodes_d)])

        return self.generator(g.ndata['x'][nids['dec']])

注意

通过调用update_graph函数,你几乎可以使用相同的代码在任何子图上创建自己的 Transformer。这种灵活性使我们能够发现新的稀疏结构(参考此处提到的局部注意力)。注意,在此实现中,你没有使用 mask 或 padding,这使得逻辑更清晰并节省内存。权衡之下,该实现速度较慢。

训练

本教程不涵盖原论文中提到的标签平滑(Label Smoothing)和 Noam 优化等其他技术。有关这些模块的详细描述,请阅读哈佛大学 NLP 团队撰写的《Transformer 带注释版》

任务和数据集

Transformer 是适用于各种 NLP 任务的通用框架。本教程重点介绍序列到序列学习:这是一个说明其工作原理的典型案例。

至于数据集,有两个示例任务:copy 和 sort,以及两个真实世界的翻译任务:multi30k 英德任务和 wmt14 英德任务。

  • copy 数据集:将输入序列复制到输出。(训练/验证/测试:9000, 1000, 1000)

  • sort 数据集:将输入序列排序作为输出。(训练/验证/测试:9000, 1000, 1000)

  • Multi30k 英德,将句子从英语翻译成德语。(训练/验证/测试:29000, 1000, 1000)

  • WMT14 英德,将句子从英语翻译成德语。(训练/验证/测试:4500966/3000/3003)

注意

使用 wmt14 进行训练需要多 GPU 支持,目前尚不可用。欢迎贡献!

图构建

批量处理 这与处理 Tree-LSTM 的方式类似。预先构建一个图池,包括所有可能的输入长度和输出长度组合。然后对于批次中的每个样本,调用dgl.batch将其大小的图批量处理成一个大型图。

你可以将创建图池和构建 BatchedGraph 的过程封装在dataset.GraphPooldataset.TranslationDataset中。

graph_pool = GraphPool()

data_iter = dataset(graph_pool, mode='train', batch_size=1, devices=devices)
for graph in data_iter:
    print(graph.nids['enc']) # encoder node ids
    print(graph.nids['dec']) # decoder node ids
    print(graph.eids['ee']) # encoder-encoder edge ids
    print(graph.eids['ed']) # encoder-decoder edge ids
    print(graph.eids['dd']) # decoder-decoder edge ids
    print(graph.src[0]) # Input word index list
    print(graph.src[1]) # Input positions
    print(graph.tgt[0]) # Output word index list
    print(graph.tgt[1]) # Ouptut positions
    break

输出

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8], device='cuda:0')
tensor([ 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], device='cuda:0')
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
        72, 73, 74, 75, 76, 77, 78, 79, 80], device='cuda:0')
tensor([ 81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,
         95,  96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108,
        109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122,
        123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136,
        137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150,
        151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164,
        165, 166, 167, 168, 169, 170], device='cuda:0')
tensor([171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184,
        185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198,
        199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212,
        213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225],
       device='cuda:0')
tensor([28, 25,  7, 26,  6,  4,  5,  9, 18], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8], device='cuda:0')
tensor([ 0, 28, 25,  7, 26,  6,  4,  5,  9, 18], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')

整合所有部分

在 copy 任务上训练一个单头、一层、128 维度的 Transformer。其他参数设置为默认值。

本教程不包含推理模块。推理需要 beam search。如需完整实现,请参阅GitHub 仓库

from tqdm.auto import tqdm
import torch as th
import numpy as np

from loss import LabelSmoothing, SimpleLossCompute
from modules import make_model
from optims import NoamOpt
from dgl.contrib.transformer import get_dataset, GraphPool

def run_epoch(data_iter, model, loss_compute, is_train=True):
    for i, g in tqdm(enumerate(data_iter)):
        with th.set_grad_enabled(is_train):
            output = model(g)
            loss = loss_compute(output, g.tgt_y, g.n_tokens)
    print('average loss: {}'.format(loss_compute.avg_loss))
    print('accuracy: {}'.format(loss_compute.accuracy))

N = 1
batch_size = 128
devices = ['cuda' if th.cuda.is_available() else 'cpu']

dataset = get_dataset("copy")
V = dataset.vocab_size
criterion = LabelSmoothing(V, padding_idx=dataset.pad_id, smoothing=0.1)
dim_model = 128

# Create model
model = make_model(V, V, N=N, dim_model=128, dim_ff=128, h=1)

# Sharing weights between Encoder & Decoder
model.src_embed.lut.weight = model.tgt_embed.lut.weight
model.generator.proj.weight = model.tgt_embed.lut.weight

model, criterion = model.to(devices[0]), criterion.to(devices[0])
model_opt = NoamOpt(dim_model, 1, 400,
                    th.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9))
loss_compute = SimpleLossCompute

att_maps = []
for epoch in range(4):
    train_iter = dataset(graph_pool, mode='train', batch_size=batch_size, devices=devices)
    valid_iter = dataset(graph_pool, mode='valid', batch_size=batch_size, devices=devices)
    print('Epoch: {} Training...'.format(epoch))
    model.train(True)
    run_epoch(train_iter, model,
              loss_compute(criterion, model_opt), is_train=True)
    print('Epoch: {} Evaluating...'.format(epoch))
    model.att_weight_map = None
    model.eval()
    run_epoch(valid_iter, model,
              loss_compute(criterion, None), is_train=False)
    att_maps.append(model.att_weight_map)

可视化

训练后,你可以可视化 Transformer 在 copy 任务上生成的注意力。

src_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='src')
tgt_seq = dataset.get_seq_by_id(VIZ_IDX, mode='valid', field='tgt')[:-1]
# visualize head 0 of encoder-decoder attention
att_animation(att_maps, 'e2d', src_seq, tgt_seq, 0)

image5 从图中可以看出,解码器节点逐渐学会关注输入序列中对应的节点,这是预期的行为。

多头注意力

除了在玩具任务上训练的单头注意力的可视化之外。我们还可视化了在 multi-30k 数据集上训练的单层 Transformer 网络的编码器自注意力、解码器自注意力和编码器-解码器注意力得分。

从可视化中可以看出不同注意力头的多样性,这是预期的结果。不同的注意力头学习词对之间不同的关系。

  • 编码器自注意力 image6

  • 编码器-解码器注意力 目标序列中的大多数词都关注源序列中与之相关的词,例如:生成“See”(德语)时,一些头关注“lake”;生成“Eisfischerhütte”时,一些头关注“ice”。 image7

  • 解码器自注意力 大多数词关注它们前面的几个词。 image8

自适应通用 Transformer

Google 最近的一篇研究论文《Universal Transformer》是一个示例,展示了update_graph如何适应更复杂的更新规则。

提出 Universal Transformer 是为了解决普通 Transformer 在计算上不够通用的问题,通过在 Transformer 中引入循环机制来实现。

  • Universal Transformer 的基本思想是在每个循环步骤中,通过对表示应用 Transformer 层,来重复修改序列中所有符号的表示。

  • 与普通 Transformer 相比,Universal Transformer 在其层之间共享权重,并且不固定循环时间(即 Transformer 中的层数)。

进一步的优化采用了自适应计算时间 (ACT) 机制,允许模型动态调整序列中每个位置的表示被修改的次数(下文称为步骤)。该模型也称为自适应通用 Transformer (AUT)。

在 AUT 中,你维护一个活跃节点列表。在每个步骤\(t\)中,我们通过以下方式计算此列表中所有节点的停止概率:\(h (0<h<1)\)

\[h^t_i = \sigma(W_h x^t_i + b_h)\]

然后动态决定哪些节点仍然活跃。当且仅当\(\sum_{t=1}^{T-1} h_t < 1 - \varepsilon \leq \sum_{t=1}^{T}h_t\)时,节点在时间\(T\)停止。停止的节点将从列表中移除。该过程持续进行,直到列表为空或达到预定义的最多步骤。从 DGL 的角度来看,这意味着“活跃”图会随时间变得越来越稀疏。

节点\(s_i\)的最终状态是\(x_i^t\)\(h_i^t\)加权平均的结果

\[s_i = \sum_{t=1}^{T} h_i^t\cdot x_i^t\]

在 DGL 中,通过对仍然活跃的节点和与这些节点相关的边调用update_graph来实现算法。以下代码展示了 DGL 中的 Universal Transformer 类

class UTransformer(nn.Module):
    "Universal Transformer(https://arxiv.org/pdf/1807.03819.pdf) with ACT(https://arxiv.org/pdf/1603.08983.pdf)."
    MAX_DEPTH = 8
    thres = 0.99
    act_loss_weight = 0.01
    def __init__(self, encoder, decoder, src_embed, tgt_embed, pos_enc, time_enc, generator, h, d_k):
        super(UTransformer, self).__init__()
        self.encoder,  self.decoder = encoder, decoder
        self.src_embed, self.tgt_embed = src_embed, tgt_embed
        self.pos_enc, self.time_enc = pos_enc, time_enc
        self.halt_enc = HaltingUnit(h * d_k)
        self.halt_dec = HaltingUnit(h * d_k)
        self.generator = generator
        self.h, self.d_k = h, d_k

    def step_forward(self, nodes):
        # add positional encoding and time encoding, increment step by one
        x = nodes.data['x']
        step = nodes.data['step']
        pos = nodes.data['pos']
        return {'x': self.pos_enc.dropout(x + self.pos_enc(pos.view(-1)) + self.time_enc(step.view(-1))),
                'step': step + 1}

    def halt_and_accum(self, name, end=False):
        "field: 'enc' or 'dec'"
        halt = self.halt_enc if name == 'enc' else self.halt_dec
        thres = self.thres
        def func(nodes):
            p = halt(nodes.data['x'])
            sum_p = nodes.data['sum_p'] + p
            active = (sum_p < thres) & (1 - end)
            _continue = active.float()
            r = nodes.data['r'] * (1 - _continue) + (1 - sum_p) * _continue
            s = nodes.data['s'] + ((1 - _continue) * r + _continue * p) * nodes.data['x']
            return {'p': p, 'sum_p': sum_p, 'r': r, 's': s, 'active': active}
        return func

    def propagate_attention(self, g, eids):
        # Compute attention score
        g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
        g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids)
        # Send weighted values to target nodes
        g.send_and_recv(eids,
                        [fn.u_mul_e('v', 'score', 'v'), fn.copy_e('score', 'score')],
                        [fn.sum('v', 'wv'), fn.sum('score', 'z')])

    def update_graph(self, g, eids, pre_pairs, post_pairs):
        "Update the node states and edge states of the graph."
        # Pre-compute queries and key-value pairs.
        for pre_func, nids in pre_pairs:
            g.apply_nodes(pre_func, nids)
        self.propagate_attention(g, eids)
        # Further calculation after attention mechanism
        for post_func, nids in post_pairs:
            g.apply_nodes(post_func, nids)

    def forward(self, graph):
        g = graph.g
        N, E = graph.n_nodes, graph.n_edges
        nids, eids = graph.nids, graph.eids

        # embed & pos
        g.nodes[nids['enc']].data['x'] = self.src_embed(graph.src[0])
        g.nodes[nids['dec']].data['x'] = self.tgt_embed(graph.tgt[0])
        g.nodes[nids['enc']].data['pos'] = graph.src[1]
        g.nodes[nids['dec']].data['pos'] = graph.tgt[1]

        # init step
        device = next(self.parameters()).device
        g.ndata['s'] = th.zeros(N, self.h * self.d_k, dtype=th.float, device=device)    # accumulated state
        g.ndata['p'] = th.zeros(N, 1, dtype=th.float, device=device)                    # halting prob
        g.ndata['r'] = th.ones(N, 1, dtype=th.float, device=device)                     # remainder
        g.ndata['sum_p'] = th.zeros(N, 1, dtype=th.float, device=device)                # sum of pondering values
        g.ndata['step'] = th.zeros(N, 1, dtype=th.long, device=device)                  # step
        g.ndata['active'] = th.ones(N, 1, dtype=th.uint8, device=device)                # active

        for step in range(self.MAX_DEPTH):
            pre_func = self.encoder.pre_func('qkv')
            post_func = self.encoder.post_func()
            nodes = g.filter_nodes(lambda v: v.data['active'].view(-1), nids['enc'])
            if len(nodes) == 0: break
            edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['ee'])
            end = step == self.MAX_DEPTH - 1
            self.update_graph(g, edges,
                              [(self.step_forward, nodes), (pre_func, nodes)],
                              [(post_func, nodes), (self.halt_and_accum('enc', end), nodes)])

        g.nodes[nids['enc']].data['x'] = self.encoder.norm(g.nodes[nids['enc']].data['s'])

        for step in range(self.MAX_DEPTH):
            pre_func = self.decoder.pre_func('qkv')
            post_func = self.decoder.post_func()
            nodes = g.filter_nodes(lambda v: v.data['active'].view(-1), nids['dec'])
            if len(nodes) == 0: break
            edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['dd'])
            self.update_graph(g, edges,
                              [(self.step_forward, nodes), (pre_func, nodes)],
                              [(post_func, nodes)])

            pre_q = self.decoder.pre_func('q', 1)
            pre_kv = self.decoder.pre_func('kv', 1)
            post_func = self.decoder.post_func(1)
            nodes_e = nids['enc']
            edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['ed'])
            end = step == self.MAX_DEPTH - 1
            self.update_graph(g, edges,
                              [(pre_q, nodes), (pre_kv, nodes_e)],
                              [(post_func, nodes), (self.halt_and_accum('dec', end), nodes)])

        g.nodes[nids['dec']].data['x'] = self.decoder.norm(g.nodes[nids['dec']].data['s'])
        act_loss = th.mean(g.ndata['r']) # ACT loss

        return self.generator(g.ndata['x'][nids['dec']]), act_loss * self.act_loss_weight

调用filter_nodesfilter_edge来查找仍然活跃的节点/边

注意

  • filter_nodes() 接受一个谓词和一个节点 ID 列表/张量作为输入,然后返回满足给定谓词的节点 ID 张量。

  • filter_edges() 接受一个谓词和一个边 ID 列表/张量作为输入,然后返回满足给定谓词的边 ID 张量。

如需完整实现,请参阅GitHub 仓库

下图显示了自适应计算时间的效果。句子中不同位置的表示被修改了不同的次数。

image9

你还可以可视化 AUT 在 sort 任务训练过程中(达到 99.7% 准确率)节点上步骤分布的动态变化,这展示了 AUT 如何在训练过程中学习减少循环步骤。 image10

注意

由于存在许多依赖项,该 Notebook 本身无法直接执行。下载7_transformer.py,并将 Python 脚本复制到目录examples/pytorch/transformer,然后运行python 7_transformer.py以查看其工作原理。

脚本总运行时间: (0 分钟 0.000 秒)

由 Sphinx-Gallery 生成的图库