Skip to content

2.3 注意力机制

在序列建模任务中,模型需要判断"当前应该关注哪里"。传统的循环网络将整个输入序列压缩为一个固定长度的上下文向量,当序列较长时,早期的信息不可避免地被稀释甚至丢失。注意力机制(Attention Mechanism)的核心思想是:不再依赖单一的压缩表示,而是让模型在每一步都能动态地、有选择地"回看"输入序列的所有位置,根据当前需求为不同位置分配不同的权重。这一思想最终发展为 Transformer 架构的基石,彻底改变了自然语言处理乃至整个深度学习的面貌。

本节将从注意力机制的基本框架出发,依次介绍 Query-Key-Value 抽象、三种主要的注意力计分函数、Bahdanau 注意力的历史背景,以及 Transformer 中最核心的多头注意力设计。

2.3.1 Query-Key-Value 框架

注意力机制可以用一个统一的抽象来描述:查询(Query)键(Key)值(Value)

一个直观的比喻是:想象你走进一家自助餐厅。你带着个人口味偏好(Query)去审视每道菜的标签和外观(Key),根据匹配程度决定拿取多少(注意力权重),最终盘子里装的是实际的食物(Value)的加权组合。

形式化地,设有一组键值对 {(k1,v1),(k2,v2),,(kn,vn)} 和一个查询 q。注意力机制的计算分为三步:

  1. 计算注意力分数:用某种评分函数 a(q,ki) 衡量查询与每个键的匹配程度,得到原始分数 ei=a(q,ki)
  2. 归一化为权重:通过 softmax 将分数转化为概率分布,即 αi=exp(ei)j=1nexp(ej),确保所有权重非负且和为 1。
  3. 加权汇聚:用权重对值进行加权求和,得到注意力输出 o=i=1nαivi

注意力机制的 QKV 计算流程

图 2-6:注意力机制的计算流程。查询(Query)与每个键(Key)计算注意力分数,经 softmax 归一化后作为权重对值(Value)进行加权求和,得到最终输出。

整个过程可以简洁地表达为:

Attention(q,K,V)=softmax(a(q,K))V

其中 KV 分别为所有键和值的矩阵形式。这一框架的强大之处在于,通过改变评分函数 a(,) 的定义,可以衍生出不同的注意力变体。

2.3.2 注意力计分函数

评分函数 a(q,k) 决定了查询与键之间"相关性"的度量方式。历史上,研究者提出了三种主要的计分函数。

加性注意力(Additive Attention)。 也称为 Bahdanau 注意力中使用的计分方式。它通过一个前馈网络来学习查询与键之间的匹配关系:

a(q,k)=wvtanh(Wqq+Wkk)

其中 WqRh×dqWkRh×dkwvRh 均为可学习参数,h 为隐藏层维度。加性注意力的优点是不要求查询和键的维度相同(通过 WqWk 分别投影到同一维度 h),灵活性较高。缺点是引入了额外的参数和非线性运算,计算开销相对较大。

点积注意力(Dot-Product Attention)。 当查询和键的维度相同(dq=dk=d)时,最简单的计分方式是直接计算它们的点积:

a(q,k)=qk

点积的几何意义是衡量两个向量的方向一致性:方向越一致,点积越大,注意力权重越高。点积注意力无需额外参数,计算效率高,且能够充分利用矩阵乘法的硬件加速优势。

缩放点积注意力(Scaled Dot-Product Attention)。 纯点积注意力存在一个隐患:当向量维度 dk 较大时,点积结果的方差也随之增大。假设查询和键的各分量都是均值为 0、方差为 1 的独立随机变量,则 qk=i=1dkqiki 的均值为 0,方差为 dk。当 dk 很大时,点积值可能变得非常大或非常小,导致 softmax 输出趋于极端——大部分权重集中在少数几个位置,梯度接近于零,训练难以进行。

为此,Vaswani 等人(2017)在 "Attention Is All You Need" 中提出了除以 dk 进行缩放:

a(q,k)=qkdk

缩放后,点积的方差被重新归一化为 1,softmax 的输入维持在合理范围内,梯度更加稳定。这就是 缩放点积注意力,也是现代 Transformer 中的标准选择。

将上述公式推广到矩阵形式——设查询矩阵 QRn×dk,键矩阵 KRm×dk,值矩阵 VRm×dv——则缩放点积注意力的完整表达式为:

Attention(Q,K,V)=softmax(QKdk)V

其中 QKRn×m 是注意力分数矩阵,其第 i 行第 j 列表示第 i 个查询与第 j 个键之间的匹配程度。softmax 按行进行归一化,使每个查询的注意力权重之和为 1。最后乘以 V 得到 n×dv 的输出矩阵,其中每一行是对应查询的注意力汇聚结果。

三种计分函数的对比如下:

计分函数公式额外参数适用场景
加性注意力wvtanh(Wqq+Wkk)Wq,Wk,wvdqdk
点积注意力qk低维场景
缩放点积注意力qkdkTransformer 标准选择

表 2-3:三种注意力计分函数的对比。

2.3.3 Bahdanau 注意力

注意力机制的概念在 Bahdanau 等人(2015)的工作中被首次系统地引入到序列到序列模型中。在传统的编码器-解码器架构中,编码器将输入序列 (x1,x2,,xT) 逐步处理,最终产生一个固定长度的上下文向量 c,解码器再基于 c 逐步生成输出序列。这种设计的瓶颈在于,所有输入信息必须被压缩到一个固定大小的向量中——对于长序列,这几乎是不可能的。

Bahdanau 注意力在序列到序列模型中的应用

图 2-7:带有注意力机制的编码器-解码器架构(Bahdanau 注意力)。解码器在每一步生成输出时,不再仅依赖固定的上下文向量,而是通过注意力机制动态地关注编码器各位置的隐状态,生成当前步特有的上下文向量。

Bahdanau 注意力的核心改进是:为解码器的每一步生成一个专属的上下文向量。具体而言,设编码器在位置 j 的隐状态为 hj(这里 hj 同时充当键和值),解码器在时间步 t 的隐状态为 st(充当查询),则注意力计算过程为:

etj=a(st1,hj)=wvtanh(Wqst1+Wkhj)αtj=exp(etj)k=1Texp(etk)ct=j=1Tαtjhj

最终,上下文向量 ct 与解码器输入拼接后送入 RNN 单元,生成当前步的输出。注意,不同的时间步 t 会产生不同的注意力权重分布 {αt1,αt2,,αtT},从而动态聚焦于输入序列的不同部分。

Bahdanau 注意力属于交叉注意力(Cross-Attention)的范畴——查询来自解码器,键和值来自编码器,是两个不同序列之间的注意力交互。这与后来 Transformer 中的自注意力(Self-Attention)有本质区别:自注意力中,查询、键和值都来自同一个序列内部,每个位置都在与同序列的其他位置(包括自身)进行交互。

从 QKV 框架的视角来看,Bahdanau 注意力使用的是加性计分函数,且键和值是共享的(都是编码器隐状态)。虽然加性计分函数的表达能力强,但由于涉及额外的可学习参数矩阵和非线性运算,在序列长度较大时计算效率不如点积方式。这也是 Transformer 最终采用缩放点积注意力的重要原因之一。

2.3.4 自注意力与矩阵运算

在 Transformer 架构中,注意力机制被应用于一种全新的场景——自注意力(Self-Attention)。不同于 Bahdanau 注意力中查询和键分属不同序列,自注意力中查询、键和值都从同一个输入序列派生而来。每个位置的 token 都在与序列中所有其他位置(包括自身)进行交互,从而根据上下文动态更新自身的表示。

以句子"我爱吃苹果"为例。对于"苹果"这个 token,我们希望它能注意到"吃"这个词,从而理解这里的"苹果"是水果而非品牌。自注意力正是通过让每个 token 携带的查询向量与所有 token 的键向量进行匹配来实现这一点。

具体地,设输入序列的嵌入矩阵为 XRn×dmodeln 为序列长度,dmodel 为嵌入维度),通过三个可学习的线性变换生成 QKV

Q=XWQ,K=XWK,V=XWV

其中 WQ,WKRdmodel×dkWVRdmodel×dv。然后执行缩放点积注意力:

Attention(Q,K,V)=softmax(QKdk)V

这个过程可以完全通过矩阵乘法实现,没有任何循环依赖——所有 token 的注意力更新可以同步并行计算,这正是 Transformer 能够充分利用 GPU 并行能力的关键优势。

在注意力矩阵 softmax(QKdk)Rn×n 中,第 i 行第 j 列的元素 αij 表示第 i 个 token 对第 j 个 token 的注意力权重。通过这个权重矩阵左乘值矩阵 V,每个 token 获得了一个根据上下文加权汇聚的新表示。

Mask 机制。 在实际应用中,注意力计算往往需要配合掩码(Mask)。例如在自回归语言模型中,第 i 个 token 只能看到位置 1,2,,i,不能"偷看"未来的信息。为此,在注意力分数矩阵上对未来位置填充 (或一个极大的负数,如 109),经 softmax 后这些位置的权重变为 0,实现了因果掩码(Causal Mask)。此外,对于 padding 位置的 token,也需要通过掩码将其排除在注意力计算之外。

2.3.5 多头注意力

单一的注意力函数只能从一种"视角"捕捉序列内的依赖关系。然而在自然语言中,一个词可能同时承载多种维度的信息——句法角色、语义关联、单复数、时态等。如果只用一组 Q、K、V 进行注意力计算,很难同时捕捉所有这些层面的关系。

多头注意力(Multi-Head Attention)的思想是:与其用一个"大注意力"做所有事情,不如将查询、键和值分别投影到 h 个不同的低维子空间中,在每个子空间内独立地执行注意力计算,最后将结果拼接起来。每个子空间被称为一个"头"(Head),不同的头可以学习关注不同类型的依赖关系。

多头注意力结构

图 2-8:多头注意力机制。输入的 Q、K、V 经线性投影后分成多个头,各头独立执行缩放点积注意力,最后拼接并通过线性变换得到最终输出。

形式化地,给定查询 qRdq、键 kRdk、值 vRdv,第 i 个注意力头 (i=1,,h) 的计算为:

headi=Attention(qWiQ,kWiK,vWiV)

其中 WiQRdmodel×dkWiKRdmodel×dkWiVRdmodel×dv 为第 i 个头的投影矩阵。在标准设计中,取 dk=dv=dmodel/h,这样 h 个头的总参数量与单头注意力大致相当。

所有头的输出拼接后,再通过一个线性变换映射回原始维度:

MultiHead(Q,K,V)=Concat(head1,,headh)WO

其中 WOR(hdv)×dmodel 为输出投影矩阵。

展开完整的计算过程:以 dmodel=512h=8 为例,每个头的维度为 dk=dv=64。对于一个长度为 n 的序列,输入 XRn×512 首先通过 WQ,WK,WVR512×512 得到 Q,K,VRn×512,然后将最后一维 reshape 为 (h,dk)=(8,64),转置为 (8,n,64),即 8 个头各自拥有 n×64 的 Q、K、V 矩阵,独立执行缩放点积注意力。计算结果 (8,n,64) 再转置、拼接回 (n,512),最后通过 WO 得到输出。

多头注意力的关键优势在于:不同的头可以自发地学习到不同类型的注意力模式。研究者通过可视化发现,部分头倾向于捕捉局部的相邻词依赖,部分头关注长距离的语法结构,还有一些头会聚焦于特定的语义角色。这种"分工合作"机制使得 Transformer 的表达能力远超单头注意力。

2.3.6 PyTorch 实现

以下提供一个自包含的 PyTorch 实现,涵盖缩放点积注意力和多头注意力。

python
import torch
import torch.nn as nn
import math


def scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor | None = None,
    dropout: nn.Dropout | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """缩放点积注意力。

    Args:
        query: 形状 (batch, ..., seq_q, d_k)
        key:   形状 (batch, ..., seq_k, d_k)
        value: 形状 (batch, ..., seq_k, d_v)
        mask:  可选,形状可广播至 (batch, ..., seq_q, seq_k),
               值为 0 的位置将被屏蔽
        dropout: 可选的 Dropout 层,作用于注意力权重

    Returns:
        output: 形状 (batch, ..., seq_q, d_v)
        attn_weights: 形状 (batch, ..., seq_q, seq_k)
    """
    d_k = query.size(-1)
    # (batch, ..., seq_q, seq_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))

    attn_weights = torch.softmax(scores, dim=-1)

    if dropout is not None:
        attn_weights = dropout(attn_weights)

    output = torch.matmul(attn_weights, value)
    return output, attn_weights


class MultiHeadAttention(nn.Module):
    """多头注意力模块。

    Args:
        d_model: 模型隐藏维度
        num_heads: 注意力头数
        dropout: 注意力权重的 dropout 概率
        bias: 线性投影是否包含偏置项
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = False,
    ):
        super().__init__()
        assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # 每个头的维度

        self.W_q = nn.Linear(d_model, d_model, bias=bias)
        self.W_k = nn.Linear(d_model, d_model, bias=bias)
        self.W_v = nn.Linear(d_model, d_model, bias=bias)
        self.W_o = nn.Linear(d_model, d_model, bias=bias)

        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            query: (batch, seq_q, d_model)
            key:   (batch, seq_k, d_model)
            value: (batch, seq_k, d_model)
            mask:  可选,(batch, 1, seq_q, seq_k) 或可广播形状

        Returns:
            output: (batch, seq_q, d_model)
        """
        batch_size = query.size(0)

        # 线性投影: (batch, seq, d_model) -> (batch, seq, d_model)
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        # 拆分多头: (batch, seq, d_model) -> (batch, num_heads, seq, d_k)
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # 缩放点积注意力: (batch, num_heads, seq_q, d_k)
        output, _ = scaled_dot_product_attention(Q, K, V, mask, self.dropout)

        # 合并多头: (batch, num_heads, seq_q, d_k) -> (batch, seq_q, d_model)
        output = (
            output.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.d_model)
        )

        # 输出投影
        return self.W_o(output)

代码解读。 scaled_dot_product_attention 函数实现了标准的缩放点积注意力:计算 QK/dk,应用 mask(如有),softmax 归一化,最后乘以 VMultiHeadAttention 类在此基础上实现了多头机制:通过四个线性层分别生成 Q、K、V 和输出投影,利用 viewtranspose 操作将 (batch, seq, d_model) 重塑为 (batch, num_heads, seq, d_k) 以实现多头的并行计算,最后再拼接回原始维度。

可以用以下代码快速验证:

python
# 验证多头注意力的输入输出形状
batch_size, seq_len, d_model, num_heads = 2, 10, 512, 8
mha = MultiHeadAttention(d_model, num_heads, dropout=0.1)
x = torch.randn(batch_size, seq_len, d_model)
output = mha(x, x, x)  # 自注意力: Q=K=V=x
print(output.shape)     # torch.Size([2, 10, 512])

querykeyvalue 传入相同的张量时,这就是自注意力;当 query 来自解码器、keyvalue 来自编码器时,这就是交叉注意力。同一个 MultiHeadAttention 类可以同时支持这两种用法。

本节小结

本节系统介绍了注意力机制的核心概念和实现:

  • QKV 框架 是注意力机制的统一抽象:查询与键计算匹配分数,softmax 归一化为权重后对值加权求和。
  • 三种计分函数 各有适用场景:加性注意力灵活但较慢,点积注意力高效但高维时不稳定,缩放点积注意力通过除以 dk 解决了方差问题,成为 Transformer 的标准选择。
  • Bahdanau 注意力 首次将注意力机制引入序列到序列模型,让解码器在每一步动态关注编码器的不同位置,突破了固定长度上下文向量的瓶颈。
  • 自注意力 让序列内部的 token 互相交互,可以完全通过矩阵运算并行计算,是 Transformer 摒弃循环结构的关键。
  • 多头注意力 通过在多个子空间中独立执行注意力计算,使模型能够同时捕捉不同类型的依赖关系——局部语法、长距离语义、句法结构等——大幅增强了模型的表达能力。

这些概念共同构成了 Transformer 架构的注意力基础。下一节将在此基础上,完整介绍 Transformer 的编码器-解码器结构、位置编码、层归一化等组件,展示这些零件如何组装成一个完整的模型。