Skip to content

6.3 MLA(多头潜在注意力)

上一节介绍的 GQA 通过让多个查询头共享同一组 Key/Value 来压缩 KV 缓存,其本质是一种结构化的低秩约束——同一组内的所有头被强制使用完全相同的 K 和 V 向量。这种"复制"策略虽然简单高效,但在表达能力上存在天然上限:共享 KV 的头无法从相同输入中提取不同的键值信息。

多头潜在注意力(Multi-Head Latent Attention, MLA)提供了一条截然不同的压缩路径。它不再限制哪些头共享 KV,而是将所有头的 Key 和 Value 联合投影到一个低维潜在空间,推理时仅缓存这个低维向量;当实际需要计算注意力时,再将其投影回高维空间恢复出各头的 K 和 V。MLA 由 DeepSeek-V2 首次提出并引入实际部署,随后在 DeepSeek-V3 和 DeepSeek-R1 中延续使用。消融实验表明,在同等 KV 缓存预算下,MLA 的建模性能不仅超过 GQA,甚至略优于标准 MHA。

本节将从数学原理出发,完整推导 MLA 的低秩压缩机制、矩阵吸收技巧和解耦位置编码设计,并给出 PyTorch 参考实现。

MLA 多头潜在注意力架构示意

图 6-9:MLA 架构概览。Key 和 Value 被联合投影到低维潜在空间,推理时仅缓存潜向量,通过矩阵吸收避免恢复高维表示。

6.3.1 核心思想:低秩联合压缩

MHA 的 KV 缓存瓶颈。 回顾标准 MHA:对于第 t 个 token 的输入向量 xtRd,第 i 个注意力头通过独立的投影矩阵生成键和值:

kt(i)=xtWK(i),vt(i)=xtWV(i),WK(i),WV(i)Rd×dh

其中 d 为隐层维度,h 为注意力头数,dh=d/h 为每头维度。KV 缓存需要存储每个 token 的全部头的 Key 和 Value,单 token 缓存量为 2×h×dh=2d 个浮点数。

GQA 的约束本质。 GQA 通过将 h 个头划分为 G 组,组内共享 K/V,使缓存量降至 2×G×dh。但从线性代数视角看,这等价于强制同一组内的投影矩阵 WK(i) 完全相同——一种非常"刚性"的低秩约束。

MLA 的柔性低秩压缩。 MLA 引入一个维度为 dcd 的潜在空间。对于第 t 个 token,首先将输入 xt 通过一个下采样矩阵 WDKVRd×dc 投影到低维:

ctKV=xtWDKVRdc

这个低维向量 ctKV 称为 KV 潜向量(latent vector)。然后通过各头独立的上采样矩阵恢复出高维的 Key 和 Value:

kt(i)=ctKVWUK(i),vt(i)=ctKVWUV(i),WUK(i),WUV(i)Rdc×dh

展开可以看到,MLA 中第 i 头的 Key 实际上是 kt(i)=xtWDKVWUK(i),等效于一个秩不超过 dc 的投影。与 GQA 的"复制共享"不同,MLA 允许每个头通过独立的上采样矩阵从同一个潜向量中提取不同的信息,提供了更大的表达灵活度。

此时的问题。 如果按上述公式直接计算,推理阶段仍然需要将 ctKV 乘以上采样矩阵恢复出完整的多头 K/V,然后将恢复后的高维向量存入 KV 缓存——缓存大小依然是 2×h×dh,压缩毫无意义。要真正实现缓存压缩,需要一个关键的代数技巧:矩阵吸收

6.3.2 矩阵吸收:真正实现缓存压缩

矩阵吸收利用的是矩阵乘法的结合律(Associativity),将上采样矩阵从 KV 端转移到 Query 端和输出端,从而使推理时无需显式恢复高维 K/V。

Key 端的吸收。 考虑第 i 个头的注意力分数(省略缩放因子):

Scoret,j(i)=qt(i)(kj(i))T

代入 MLA 的低秩表达式:

Scoret,j(i)=(xtWQ(i))(cjKVWUK(i))T=xtWQ(i)(WUK(i))T(cjKV)T

注意 WQ(i)(WUK(i))T 是两个常数矩阵的乘积,可以在推理前离线预计算为一个吸收矩阵

W~Q(i)=WQ(i)(WUK(i))TRd×dc

推理时,Query 的计算变为 q~t(i)=xtW~Q(i)Rdc,注意力分数直接用潜向量计算:

Scoret,j(i)=q~t(i)(cjKV)T

Value 端的吸收。 类似地,第 i 个头的注意力输出为:

ot(i)=jAttnt,jvj(i)=jAttnt,jcjKVWUV(i)

多头输出拼接后需要乘以输出投影矩阵 WO。根据结合律,WUV(i) 可以被吸收到 WO 中,推理时直接对潜向量 cjKV 做加权求和,再乘以预计算的吸收矩阵即可。

吸收后的推理流程。 经过 Key 端和 Value 端的矩阵吸收,推理阶段的 KV 缓存中只需要存储 cjKVRdc,无需恢复出任何高维 K/V 向量。单 token 的 KV 缓存从 MHA 的 2×h×dh=2d 个元素骤降为 dc 个元素。

代价是推理时多了一步矩阵乘法(吸收矩阵与潜向量的乘法)。但在 LLM 解码阶段,瓶颈是显存带宽而非计算量(Memory-Bound),用少量额外计算换取大幅缓存压缩,在吞吐量上是净收益。

6.3.3 解耦旋转位置编码

MLA 低秩投影的第一步

图 6-10:MLA 低秩投影过程。Key 和 Value 被联合投影到低维潜在空间,潜向量的维度远小于原始 KV 维度,实现缓存压缩。

上述矩阵吸收有一个前提:WQ(i)(WUK(i))T 之间不能有依赖于 token 位置的动态矩阵。然而,主流 LLM 广泛采用的旋转位置编码(RoPE)恰好破坏了这一前提。

RoPE 导致吸收失效。 RoPE 通过一个依赖于绝对位置 t 的正交矩阵 Rt 作用于 Query 和 Key。施加 RoPE 后,注意力分数变为:

Scoret,j(i)=(xtWQ(i)Rt)(cjKVWUK(i)Rj)T=xtWQ(i)RtRjT=Rtj(WUK(i))T(cjKV)T

位置相关的动态矩阵 Rtj 夹在 WQ(i)(WUK(i))T 之间,两个常数矩阵无法预先合并为吸收矩阵。如果不进行吸收,就必须在每步推理时将 cjKV 动态恢复为完整的 kj(i) 以注入位置信息,KV 缓存压缩的优势完全丧失。

解耦策略。 MLA 的解决方案是将 Query 和 Key 各自拆分为内容部分(不含 RoPE)和位置部分(含 RoPE),两部分拼接后共同参与注意力计算。

具体地,引入一个较小的 RoPE 专用维度 dr(DeepSeek-V2 中取 dr=64):

Query 解耦:

qt(i)=[qt,content(i)dh 维,qt,ropedr 维]
  • 内容部分:qt,content(i)=ctQWUQ(i)(来自 Query 潜向量,不受 RoPE 影响)
  • 位置部分:qt,rope=RoPE(ctQWQR)(施加 RoPE,所有头共享)

Key 解耦:

kj(i)=[kj,content(i)dh 维,kj,ropedr 维]
  • 内容部分:kj,content(i)=cjKVWUK(i)(不受 RoPE 影响)
  • 位置部分:kj,rope=RoPE(xjWKR)(施加 RoPE,所有头共享)

拼接后的注意力分数:

Scoret,j(i)=1dh+dr[qt,content(i)(kj,content(i))T内容项:无 RoPE,可矩阵吸收+qt,rope(kj,rope)T位置项:含 RoPE,单独计算]

通过这种解耦:

  1. 内容项完全遵循 6.3.2 节的结合律,推理时只需缓存 cjKV
  2. 位置项需要额外缓存所有头共享的 kj,ropeRdr,但 dr 很小(通常为 64)。

因此,MLA 推理阶段每个 token 的实际 KV 缓存大小为 dc+dr 个元素。

6.3.4 PyTorch 实现

以下是 MLA 注意力模块的参考实现。为突出核心机制,省略了解耦 RoPE 和矩阵吸收优化,采用"训练模式"的显式低秩压缩—恢复流程:

python
import torch
import torch.nn as nn


class MultiHeadLatentAttention(nn.Module):
    def __init__(self, d_in, d_out, num_heads, latent_dim,
                 dropout=0.0, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.d_out = d_out
        self.latent_dim = latent_dim

        # 下采样:将输入投影到低维潜在空间
        self.W_DKV = nn.Linear(d_in, latent_dim, bias=qkv_bias)

        # 上采样:从潜向量恢复多头 Key 和 Value
        self.W_UK = nn.Linear(latent_dim, d_out, bias=False)
        self.W_UV = nn.Linear(latent_dim, d_out, bias=False)

        # Query 投影(不压缩,与标准 MHA 相同)
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out, bias=False)
        self.dropout = nn.Dropout(dropout)

        # KV Cache 缓冲区:缓存低维潜向量而非高维 K/V
        self.register_buffer("cache_c", None, persistent=False)
        self.ptr_current_pos = 0

    def forward(self, x, use_cache=False):
        b, seq_len, _ = x.shape

        # Query:标准投影
        Q = self.W_query(x)
        Q = Q.view(b, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # KV:先下采样到潜在空间
        c_kv = self.W_DKV(x)  # (b, seq_len, latent_dim)

        # 缓存逻辑:缓存的是低维潜向量,而非高维 K/V
        if use_cache:
            if self.cache_c is None:
                self.cache_c = c_kv
            else:
                self.cache_c = torch.cat([self.cache_c, c_kv], dim=1)
            c_kv_full = self.cache_c
        else:
            c_kv_full = c_kv

        # 从潜向量恢复高维 Key 和 Value
        K = self.W_UK(c_kv_full)
        V = self.W_UV(c_kv_full)

        K = K.view(b, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(b, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # 缩放点积注意力 + 因果掩码
        scale = self.head_dim ** 0.5
        attn_scores = Q @ K.transpose(-2, -1) / scale

        num_q = Q.size(2)
        num_k = K.size(2)
        if use_cache and num_q < num_k:
            # Decode 阶段:动态构建因果掩码
            offset = num_k - num_q
            row_idx = torch.arange(num_q, device=x.device).unsqueeze(1)
            col_idx = torch.arange(num_k, device=x.device).unsqueeze(0)
            mask = col_idx > row_idx + offset
        else:
            mask = torch.triu(
                torch.ones(num_q, num_k, device=x.device, dtype=torch.bool),
                diagonal=1
            )
        attn_scores = attn_scores.masked_fill(mask, float("-inf"))

        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context = (attn_weights @ V).transpose(1, 2).contiguous()
        context = context.view(b, num_q, self.d_out)
        return self.out_proj(context)

    def reset_cache(self):
        self.cache_c = None
        self.ptr_current_pos = 0

代码要点解读:

  1. W_DKV 是关键新增模块——它将 d 维输入压缩到 dc 维(latent_dim),是 MLA 相对于 MHA 的核心结构差异。
  2. 缓存的对象是 c_kv(低维潜向量),而非高维 K/V 向量。 这是 MLA 节省 KV 缓存内存的直接来源:缓存从每 token 2×d 个元素降至 dc 个元素。
  3. W_UKW_UV 在每次注意力计算时在线执行,将潜向量恢复为高维 K/V。这增加了推理计算量,但由于解码阶段受限于显存带宽而非计算吞吐,这一代价在实际部署中通常可以忽略。
  4. 训练模式 vs 推理模式。 上述实现采用的是训练模式——显式计算高维 K/V 后再做注意力。在优化后的推理模式中,可以通过矩阵吸收(6.3.2 节)将 W_UK 合并到 Query 投影、W_UV 合并到输出投影,从而在推理时完全跳过高维 K/V 的恢复过程。

6.3.5 KV 缓存内存对比

MLA 矩阵吸收的推理优化

图 6-11:MLA 矩阵吸收过程。推理时将上采样矩阵吸收到 Query 和输出投影中,直接对潜向量计算注意力,避免恢复高维 KV 表示。

MLA 的 KV 缓存公式与 MHA/GQA 有本质不同。MHA 缓存的是高维 K/V 向量对,而 MLA 缓存的是低维潜向量(加上解耦 RoPE 的位置向量):

方法单 token 单层缓存元素数公式
MHA2×h×dh2d
GQA(G 组)2×G×dh2Gdh
MLAdc+dr潜向量 + RoPE 键

MLA 的总 KV 缓存字节数为:

KV_bytesMLA=B×L×S×(dc+dr)×bytes_per_elem

其中 B 为 batch size,L 为层数,S 为序列长度。注意 MLA 的公式中没有因子 2(K/V 两份),也没有头数 h——这是因为所有头共享同一个潜向量 ctKV,RoPE 键 ktrope 也在所有头间共享。

对比一:理论配置。 使用 DeepSeek-V2 的参数(d=5120h=128dh=128dc=512dr=64):

方法单 token 单层缓存压缩比(相对 MHA)
MHA2×128×128=32,768 元素1x
MLA512+64=576 元素56.9x

这是一个惊人的 57 倍压缩,远超 GQA 通常能达到的 4-8 倍。

对比二:实验配置。 使用一个小规模对比实验的参数(emb_dim=768n_heads=24n_layers=12context_length=32768latent_dim=192batch_size=1dtype=bf16):

MHA 的 KV 缓存:

1×12×32,768×768×2×2=1,207,959,552 bytes1.21 GB

MLA 的 KV 缓存(仅计算潜向量,dr 在此简化实现中未引入):

1×12×32,768×192×2=150,994,944 bytes0.15 GB

压缩比为 (768×2)/192=8 倍。在包含模型参数和前馈层的完整模型中,实测端到端内存从 MHA 的 1.54 GB 降至 MLA 的 0.68 GB——KV 缓存被大幅压缩,而模型参数和前馈层的固定开销稀释了压缩比例。

配置KV 缓存模型总内存(含 FFN)端到端压缩
MHA(emb_dim=768~1.21 GB1.54 GB1.0x
MLA(latent_dim=192~0.15 GB0.68 GB2.3x

表 6-4:MHA 与 MLA 端到端内存对比(emb_dim=768, n_heads=24, n_layers=12, context_length=32768, bf16)。KV 缓存本身压缩 8 倍,但模型参数的固定开销使整体压缩比降至约 2.3 倍。

需要注意的是,KV 缓存的节省效果与上下文长度成正比。在上述配置中,当上下文长度从 32K 增至 128K,KV 缓存将成为总内存的主导项,MLA 的 8 倍压缩优势将更加充分地体现。

MLA 矩阵吸收过程示意

图 6-12:MLA 矩阵吸收技巧。将 Key 端的上采样矩阵吸收到 Query 投影中,Value 端的上采样矩阵吸收到输出投影中,推理时直接对潜向量计算注意力。

6.3.6 DeepSeek-V2 消融实验

DeepSeek-V2 论文提供了一组关键消融实验,直接对比了 MHA、GQA 和 MLA 在相同训练设置下的建模性能(以困惑度 / Perplexity 或下游任务准确率衡量)。实验结论可以总结为两点:

  1. GQA 性能低于 MHA。 在同等模型规模下,GQA 虽然节省了 KV 缓存,但在建模质量上出现了可测量的下降。这与预期一致——GQA 强制多个头共享完全相同的 K/V,相当于对注意力空间施加了刚性约束,限制了模型的表达能力。

  2. MLA 性能略优于 MHA。 在 KV 缓存压缩幅度远超 GQA 的情况下,MLA 不仅没有损失建模质量,反而略微超过了标准 MHA。这一看似反直觉的结果有合理的解释:低秩投影充当了一种信息瓶颈(Information Bottleneck),迫使模型学习更紧凑、更具泛化性的 KV 表示,类似于自编码器中瓶颈层的正则化效果。

这一消融结果是 DeepSeek 团队选择 MLA 而非 GQA 的核心依据。从工程部署角度看,MLA 实现了"缓存更小、性能更好"的理想组合——这在 GQA 的框架下是不可能达到的。

6.3.7 训练与推理的模式差异

MLA 在训练和推理阶段的计算模式存在显著差异,这是理解其工程实现的关键:

训练阶段采用类似 MHA 的并行计算模式,显式地计算所有头的 qt(i)kt(i)vt(i)。此时不进行矩阵吸收,因为训练需要通过 WUKWUV 的梯度来更新这些上采样矩阵。Query 同样可以进行低秩压缩(引入 ctQRdc)以减少训练阶段的激活值显存占用,但这与推理时的 KV 缓存无关。

推理阶段分为三步:

  1. 离线预处理:将上采样矩阵分别吸收到 WQWO 中,预计算吸收矩阵。
  2. Prefill 阶段:并行计算所有 token 的 ctKVktrope,存入缓存。
  3. Decode 阶段:形式上退化为等效的 MQA——所有头共享同一个 cjKV 进行注意力计算,访存量极大降低。

第三步值得特别说明:由于矩阵吸收后 Query 的维度变为 dc、Value 的等效维度也变为 dc,Decode 阶段的矩阵乘法计算量(FLOPs)实际上高于标准 MHA。但 LLM 解码阶段是极度的 Memory-Bound,用富余的计算资源换取极低的访存带宽,这正是 MLA 能够大幅提升推理吞吐量的本质原因。

本节小结

本节介绍了多头潜在注意力(MLA)的原理与实现:

  • 核心思想:将所有头的 Key 和 Value 联合投影到一个低维潜在空间(维度 dc),推理时仅缓存低维潜向量 ctKV,通过矩阵吸收技巧避免恢复高维 K/V。与 GQA 的"复制共享"不同,MLA 允许每个头从共享潜向量中提取不同信息,提供更大的表达灵活度。
  • 矩阵吸收:利用矩阵乘法的结合律,将 Key 端的上采样矩阵 WUK 吸收到 Query 投影中,将 Value 端的 WUV 吸收到输出投影中。吸收后推理时只需对潜向量执行注意力计算,无需显式恢复高维 K/V。
  • 解耦 RoPE:RoPE 的位置相关矩阵会阻断矩阵吸收。MLA 将 Q/K 拆分为内容部分(可吸收)和位置部分(含 RoPE,单独计算),额外缓存一个所有头共享的 ktropeRdr
  • 内存收益:单 token 单层 KV 缓存从 MHA 的 2d 个元素降至 dc+dr 个元素。在 DeepSeek-V2 配置下实现约 57 倍压缩;在小规模实验配置下(latent_dim=192),KV 缓存压缩 8 倍,端到端模型内存从 1.54 GB 降至 0.68 GB。
  • 建模性能:DeepSeek-V2 消融实验表明,MLA 在 KV 缓存大幅压缩的前提下,建模性能不仅不低于 MHA,反而略有提升——低秩瓶颈起到了正则化效果。这是 MLA 优于 GQA 的核心优势。

延伸阅读:MLA 的高效推理实现 FlashMLA 将在 §9.4 中详细讨论,包括其针对 GPU 内存层次的优化策略。