Skip to content

6.1 KV Cache

大语言模型的推理过程是逐 token 生成的:给定已有序列,模型预测下一个 token,将其拼接到序列末尾,再用更长的序列预测再下一个 token。这一自回归生成循环中隐藏着大量冗余计算——每一步都需要对所有历史 token 重新计算注意力中的键(Key)和值(Value)向量,而这些向量在之前的步骤中已经计算过。KV Cache 正是针对这一冗余提出的优化技术:将已计算的 K、V 向量缓存起来,后续步骤直接复用,从而将逐 token 生成的计算复杂度从二次降到线性。本节将从原理出发,完整实现一个带 KV Cache 的 GPT 推理系统,并进一步讨论预分配张量和滑动窗口截断等工程优化。

自注意力机制中的 Q/K/V 计算流程

图 6-1:自注意力机制计算流程。每个 token 生成 Query、Key、Value 向量,通过注意力矩阵计算加权和。KV Cache 缓存历史 K/V 避免重复计算。

6.1.1 为什么需要 KV Cache

自回归生成的冗余计算。 回顾注意力机制的计算过程:对于输入序列中的每个 token,模型通过线性投影分别生成 Query(Q)、Key(K)、Value(V)三个向量,然后计算注意力分数 Attention(Q,K,V)=softmax(QKTdk)V

假设模型正在处理提示词 "Time flies",此时注意力计算涉及两个 token 的 K、V 向量。当模型生成了新 token "fast" 后,下一轮输入变为 "Time flies fast",需要重新计算三个 token 的所有 K、V 向量。但仔细观察可以发现:"Time" 和 "flies" 的 K、V 向量与上一轮完全相同——它们仅取决于对应位置的输入 embedding 和线性投影权重,与序列中后续 token 无关。

将这一观察推广到整个生成过程。设生成序列长度为 T,在第 t 步(t=1,2,,T),不使用缓存的朴素实现需要计算 t 个 token 的 K、V 向量。整个生成过程的总计算量正比于:

t=1Tt=T(T+1)2=O(T2)

这意味着生成长度翻倍,计算量将翻四倍。

KV Cache 的核心思想。 既然前 t1 个 token 的 K、V 向量不会改变,我们只需在第一步计算全部 K、V,之后每步仅为新生成的单个 token 计算一组 K、V,然后将其追加到缓存中。这样每步的新增计算量为常数 O(1)(仅一个 token 的投影),整个生成过程的投影计算总量降为 O(T)

以四步生成过程 "遥→遥→领→先" 为例,展开注意力计算公式:

Att1=softmax(Q1K1T)V1Att2=softmax(Q2[K1,K2]T)[V1,V2]Att3=softmax(Q3[K1,K2,K3]T)[V1,V2,V3]Att4=softmax(Q4[K1,K2,K3,K4]T)[V1,V2,V3,V4]

关键观察:

  1. Attt 只需要 Qt——当前 token 的 Query 向量。历史 token 的 Query 不需要重新计算,因为它们的输出已经在之前的步骤中使用过了。
  2. K1,,Kt1V1,,Vt1 可以复用——它们在之前的步骤中已经计算过,且值不会改变。

因此,KV Cache 的工作流程分为两个阶段:

  • Prefill(预填充)阶段:将完整的提示词序列一次性送入模型,计算所有 token 的 K、V 向量并存入缓存。
  • Decode(解码)阶段:每步仅将新生成的单个 token 送入模型,计算其 Q、K、V 向量,其中 K、V 追加到缓存,Q 与缓存中的所有 K 计算注意力分数。

6.1.2 基础实现

注意力权重矩阵与因果掩码

图 6-2:注意力权重矩阵的可视化。因果掩码确保每个位置只能关注自身及之前的 token,形成下三角矩阵模式。

下面在 MultiHeadAttention 类中实现 KV Cache。核心改动只有三处:注册缓存缓冲区、在前向传播中维护缓存、提供缓存重置方法。

第一步:注册缓存缓冲区。MultiHeadAttention 的构造函数中,使用 register_buffer 注册两个缓冲区 cache_kcache_v,并维护一个位置指针 ptr_current_pos 追踪当前缓存中已填充的 token 数:

python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout,
                 num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
            persistent=False
        )

        # KV Cache 缓冲区
        self.register_buffer("cache_k", None, persistent=False)
        self.register_buffer("cache_v", None, persistent=False)
        self.ptr_current_pos = 0  # 位置指针,追踪已缓存的 token 数

使用 register_buffer 而非普通属性,确保缓存会随模型一起迁移到目标设备(CPU/GPU),但不会被优化器视为可训练参数。设置 persistent=False 表示缓存不需要被保存到 state_dict 中——它是纯推理时的临时状态。

第二步:在前向传播中维护缓存。 扩展 forward 方法,增加 use_cache 参数。当启用缓存时,仅为新输入的 token 计算 K、V,然后将其追加到缓存中;因果掩码也需要相应调整——Query 的行索引不再从 0 开始,而是从 ptr_current_pos 开始:

python
class MultiHeadAttention(nn.Module):
    # ... __init__ 同前 ...

    def forward(self, x, use_cache=False):
        b, num_tokens, d_in = x.shape

        # 始终只为新输入的 token 计算 Q、K、V
        keys_new = self.W_key(x)
        values_new = self.W_value(x)
        queries = self.W_query(x)

        # 拆分多头: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim)
        values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # 缓存逻辑:追加新 K/V 到缓存,或直接使用新 K/V
        if use_cache:
            if self.cache_k is None:
                self.cache_k = keys_new
                self.cache_v = values_new
            else:
                self.cache_k = torch.cat([self.cache_k, keys_new], dim=1)
                self.cache_v = torch.cat([self.cache_v, values_new], dim=1)
            keys, values = self.cache_k, self.cache_v
        else:
            keys, values = keys_new, values_new

        # 转置: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # 计算注意力分数
        attn_scores = queries @ keys.transpose(2, 3)

        # 因果掩码——启用缓存时需要偏移行索引
        num_tokens_Q = queries.shape[-2]
        num_tokens_K = keys.shape[-2]
        if use_cache:
            mask_bool = self.mask.bool()[
                self.ptr_current_pos : self.ptr_current_pos + num_tokens_Q,
                :num_tokens_K
            ]
            self.ptr_current_pos += num_tokens_Q
        else:
            mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]

        attn_scores.masked_fill_(mask_bool, -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)
        return context_vec

掩码偏移是最容易出错的部分。在 Prefill 阶段,ptr_current_pos = 0num_tokens_Q 等于提示词长度,掩码从第 0 行开始截取,与不使用缓存时完全一致。在 Decode 阶段,每次只输入一个新 token,num_tokens_Q = 1,掩码从 ptr_current_pos 行截取一行。由于因果掩码的第 i 行仅允许关注位置 0i,新 token 恰好可以看到缓存中所有历史 token——这正是我们期望的行为。

第三步:缓存重置。 不同序列之间必须清空缓存,否则前一个序列的 K、V 会污染下一个序列的注意力计算:

python
class MultiHeadAttention(nn.Module):
    # ... 省略其他方法 ...

    def reset_cache(self):
        self.cache_k = None
        self.cache_v = None
        self.ptr_current_pos = 0

第四步:在模型层级传播 use_cache 需要修改 TransformerBlockGPTModel,将 use_cache 参数逐层传递。GPTModel 还需要维护自己的位置计数器,确保位置 embedding 的索引正确递增:

python
class TransformerBlock(nn.Module):
    # __init__ 不变
    def forward(self, x, use_cache=False):
        shortcut = x
        x = self.norm1(x)
        x = self.att(x, use_cache=use_cache)  # 传递 use_cache
        x = self.drop_shortcut(x)
        x = x + shortcut

        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + shortcut
        return x


class GPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        # 使用 ModuleList 替代 Sequential,以便逐层传递 use_cache
        self.trf_blocks = nn.ModuleList(
            [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )
        self.current_pos = 0  # 模型级位置计数器

        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx, use_cache=False):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)

        # 位置 embedding:缓存模式下从 current_pos 开始,否则从 0 开始
        if use_cache:
            pos_ids = torch.arange(
                self.current_pos, self.current_pos + seq_len,
                device=in_idx.device, dtype=torch.long
            )
            self.current_pos += seq_len
        else:
            pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long)

        pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
        x = tok_embeds + pos_embeds
        x = self.drop_emb(x)

        for blk in self.trf_blocks:
            x = blk(x, use_cache=use_cache)

        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits

    def reset_kv_cache(self):
        for blk in self.trf_blocks:
            blk.att.reset_cache()
        self.current_pos = 0

注意 nn.Sequential 被替换为 nn.ModuleList——前者的 forward 方法签名固定为 forward(x),无法传递额外参数;后者需要手动循环调用每个模块,但支持任意参数传递。

第五步:使用缓存的生成函数。 完整的生成流程如下:

python
def generate_with_kv_cache(model, idx, max_new_tokens, context_size=None):
    model.eval()
    ctx_len = context_size or model.pos_emb.num_embeddings

    with torch.no_grad():
        # Prefill: 将完整提示词送入模型,初始化缓存
        model.reset_kv_cache()
        logits = model(idx[:, -ctx_len:], use_cache=True)

        for _ in range(max_new_tokens):
            # 贪心解码:选择概率最高的 token
            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
            idx = torch.cat([idx, next_idx], dim=1)
            # Decode: 仅将新 token 送入模型
            logits = model(next_idx, use_cache=True)

    return idx

与不使用缓存的版本对比:

python
def generate_without_cache(model, idx, max_new_tokens, context_size):
    model.eval()
    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits = model(idx[:, -context_size:])  # 每步都送入完整序列
        next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
        idx = torch.cat([idx, next_idx], dim=1)
    return idx

关键差异在于 Decode 阶段的输入:缓存版本每步仅传入 next_idx(一个 token),朴素版本每步传入 idx[:, -context_size:](整个序列)。随着序列增长,两者的计算量差距迅速拉大。

6.1.3 优化:预分配张量与滑动窗口

因果掩码与 Prefill/Decode 阶段

图 6-3:因果掩码在 Prefill 和 Decode 阶段的不同行为。Prefill 阶段处理完整提示词,Decode 阶段每步仅处理一个新 token 并利用 KV Cache 避免重复计算。

上述基础实现虽然逻辑清晰,但存在两个工程缺陷:

  1. 反复分配内存。 每次调用 torch.cat 追加新 K、V 时,PyTorch 必须分配一块新的连续内存,将旧数据和新数据拷贝过去,然后释放旧内存。随着序列变长,这一过程涉及的数据量线性增长,导致显著的内存碎片和性能下降。
  2. 内存无上限增长。 缓存大小随生成长度线性增长,对于长序列场景,KV Cache 可能耗尽 GPU 显存。以一个具体的例子估算:batch_size=32、num_heads=32、num_layers=32、head_dim=128、seq_length=2048、float32 类型,KV Cache 需要 2×32×32×128×2048×32×4 字节 64 GB 显存——这足以耗尽大多数单卡的显存。

针对这两个问题,分别采用预分配张量和滑动窗口截断两种优化策略。

优化一:预分配张量。 在模型初始化时,根据最大序列长度一次性分配好 K、V 缓存的完整内存,后续生成过程中只需向对应位置写入数据,无需反复分配和拷贝:

python
class MultiHeadAttentionOptimized(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout,
                 num_heads, qkv_bias=False, window_size=None):
        super().__init__()
        assert d_out % num_heads == 0

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.window_size = window_size or context_length

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        # 预分配缓存(初始化为 None,首次使用时按 batch_size 分配)
        self.register_buffer("cache_k", None, persistent=False)
        self.register_buffer("cache_v", None, persistent=False)

    def forward(self, x, use_cache=False):
        b, num_tokens, _ = x.shape

        keys_new = self.W_key(x)
        values_new = self.W_value(x)
        queries = self.W_query(x)

        keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim)
        values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # 先转置再操作缓存: (b, num_heads, num_tokens, head_dim)
        keys_new = keys_new.transpose(1, 2)
        values_new = values_new.transpose(1, 2)
        queries = queries.transpose(1, 2)

        if use_cache:
            # 首次调用或 batch_size 变化时,预分配完整缓存
            if self.cache_k is None or self.cache_k.size(0) != b:
                self.cache_k = torch.zeros(
                    b, self.num_heads, self.window_size, self.head_dim,
                    device=x.device
                )
                self.cache_v = torch.zeros_like(self.cache_k)
                self.ptr_cur = 0

            # 如果追加后会溢出窗口,丢弃最早的 token
            if self.ptr_cur + num_tokens > self.window_size:
                overflow = self.ptr_cur + num_tokens - self.window_size
                self.cache_k[:, :, :-overflow, :] = \
                    self.cache_k[:, :, overflow:, :].clone()
                self.cache_v[:, :, :-overflow, :] = \
                    self.cache_v[:, :, overflow:, :].clone()
                self.ptr_cur -= overflow

            # 写入新 K/V 到缓存的对应位置(无需重新分配内存)
            self.cache_k[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = keys_new
            self.cache_v[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = values_new
            self.ptr_cur += num_tokens

            # 只取已填充部分
            keys = self.cache_k[:, :, :self.ptr_cur, :]
            values = self.cache_v[:, :, :self.ptr_cur, :]
        else:
            keys, values = keys_new, values_new

        # 注意力计算
        attn_scores = queries @ keys.transpose(2, 3)

        # 动态构建因果掩码
        K = attn_scores.size(-1)
        if num_tokens == K:
            causal_mask = torch.triu(
                torch.ones(num_tokens, K, device=x.device, dtype=torch.bool),
                diagonal=1
            )
        else:
            offset = K - num_tokens
            row_idx = torch.arange(num_tokens, device=x.device).unsqueeze(1)
            col_idx = torch.arange(K, device=x.device).unsqueeze(0)
            causal_mask = col_idx > row_idx + offset

        attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)
        return context_vec

    def reset_cache(self):
        self.cache_k = None
        self.cache_v = None

与基础版本相比,这一实现的关键变化有三处:

  1. 缓存形状固定为 (b, num_heads, window_size, head_dim),一次性分配后不再变化。写入新 K/V 时通过切片赋值 cache_k[:, :, ptr:ptr+n, :] = keys_new,这是原地操作,不触发内存分配。
  2. 先转置再存入缓存。基础版本在 (b, num_tokens, num_heads, head_dim) 维度上拼接后再转置,优化版本直接以 (b, num_heads, num_tokens, head_dim) 格式存储,减少一次转置操作。
  3. 动态因果掩码替代预存储掩码。预分配版本的缓存长度可能因滑动窗口截断而变化,不再适合使用固定的上三角掩码。通过 col_idx > row_idx + offset 动态计算掩码,其中 offset = K - num_tokens 是缓存中已有的历史 token 数量。

优化二:滑动窗口截断。 当生成长度超过 window_size 时,缓存中最早的 token 会被丢弃。上面代码中的溢出处理逻辑实现了这一策略:

python
if self.ptr_cur + num_tokens > self.window_size:
    overflow = self.ptr_cur + num_tokens - self.window_size
    # 将缓存内容整体左移 overflow 个位置
    self.cache_k[:, :, :-overflow, :] = self.cache_k[:, :, overflow:, :].clone()
    self.cache_v[:, :, :-overflow, :] = self.cache_v[:, :, overflow:, :].clone()
    self.ptr_cur -= overflow

滑动窗口将 KV Cache 的显存占用锁定在 O(window_size),不再随生成长度增长。代价是模型无法回顾窗口之外的历史 token——对于需要长距离依赖的任务,窗口大小需要仔细权衡。在实践中,当 window_size 等于模型的上下文长度时,滑动窗口不会丢弃任何信息,退化为纯预分配优化。

6.1.4 性能对比

Prefill 与 Decode 两阶段的计算模式

图 6-4:KV Cache 的两阶段推理流程。Prefill 阶段一次性处理所有输入 token 并填充缓存,Decode 阶段逐个生成新 token 并增量更新缓存。

在 GPT-2 124M 模型上(12 层、12 头、768 维 embedding、上下文长度 1024),以 "Hello, I am" 作为提示词生成 200 个新 token,三种实现的推理速度对比如下(Mac Mini M4 芯片,CPU 推理):

实现方式吞吐量 (tokens/sec)相对加速
无缓存(朴素实现)271.0x
KV Cache(基础版,torch.cat 拼接)1445.3x
KV Cache(优化版,预分配 + 滑动窗口)1666.1x

表 6-1:三种实现的推理吞吐量对比(GPT-2 124M,200 tokens 生成,CPU)。

几点说明:

  • 基础版 KV Cache 已经带来 5.3 倍加速,因为注意力层中 K/V 的投影计算从 O(T2) 降到了 O(T)
  • 优化版在基础版之上再提升约 15%,主要来自消除 torch.cat 的内存分配开销。
  • 这些数据基于 CPU 推理。在 GPU 上,由于小模型的计算量不足以充分利用 GPU 并行度,设备通信和 kernel 启动开销可能抵消 KV Cache 的收益。对于更大的模型(如 7B+ 参数),KV Cache 在 GPU 上的加速比通常更为显著。
  • 正确性验证:三种实现在相同输入下生成完全相同的 token 序列。这是验证 KV Cache 实现正确性的关键标准——缓存是纯粹的计算优化,不应改变模型的数学行为。

6.1.5 优缺点分析

优势:

  • 计算效率:K/V 投影的总计算量从 O(T2) 降至 O(T),生成越长的序列收益越大。
  • 实现简洁:核心逻辑仅涉及缓冲区注册、条件拼接和位置指针维护,对模型架构的侵入性很小。

代价:

  • 显存占用线性增长:每生成一个新 token,缓存增加 2×L×H×dk 个浮点数(L 为层数,H 为头数,dk 为每头维度,因子 2 对应 K 和 V 两份)。对于大模型和长序列,KV Cache 可能成为显存瓶颈。
  • 只能用于推理:训练阶段使用教师强制,所有位置的 K/V 本就需要完整计算,缓存无意义。
  • 增加代码复杂度use_cache 标志需要贯穿整个模型层级(Attention → TransformerBlock → GPTModel → 生成函数),增加了维护负担。

KV Cache 的显存开销催生了一系列后续优化技术,包括 Multi-Query Attention(多个 Q 头共享一组 K/V)、Grouped-Query Attention(折中方案,若干 Q 头为一组共享 K/V)、以及量化 KV Cache(将缓存从 FP16 压缩到 INT8/INT4)。这些技术将在后续章节中讨论。

因果注意力掩码示意图

图 6-5:因果注意力掩码。下三角掩码确保每个位置只能看到之前的 token,是自回归生成中 KV Cache 有效性的基础。

本节小结

本节介绍了 KV Cache 的原理与实现:

  • 原理层面:自回归生成中,历史 token 的 K/V 向量不随新 token 的加入而改变,因此可以缓存复用。缓存将 K/V 投影的总计算量从 O(T2) 降至 O(T),是 LLM 推理加速的基础技术。
  • 基础实现:通过 register_buffer 注册 cache_k/cache_v 缓冲区,使用 torch.cat 逐步追加新 K/V,配合位置指针 ptr_current_pos 正确偏移因果掩码。
  • 优化实现:预分配固定大小的张量消除反复内存分配,滑动窗口截断控制显存上限。优化后在 124M 模型上达到 166 tokens/sec,相比无缓存的 27 tokens/sec 实现 6.1 倍加速。
  • 核心权衡:KV Cache 以显存换速度。缓存大小 =2×L×H×dk×T×sizeof(dtype),对于大模型和长序列可能成为显存主要消耗项,需要配合 MQA/GQA 等架构级优化共同使用。