6.1 KV Cache
大语言模型的推理过程是逐 token 生成的:给定已有序列,模型预测下一个 token,将其拼接到序列末尾,再用更长的序列预测再下一个 token。这一自回归生成循环中隐藏着大量冗余计算——每一步都需要对所有历史 token 重新计算注意力中的键(Key)和值(Value)向量,而这些向量在之前的步骤中已经计算过。KV Cache 正是针对这一冗余提出的优化技术:将已计算的 K、V 向量缓存起来,后续步骤直接复用,从而将逐 token 生成的计算复杂度从二次降到线性。本节将从原理出发,完整实现一个带 KV Cache 的 GPT 推理系统,并进一步讨论预分配张量和滑动窗口截断等工程优化。

图 6-1:自注意力机制计算流程。每个 token 生成 Query、Key、Value 向量,通过注意力矩阵计算加权和。KV Cache 缓存历史 K/V 避免重复计算。
6.1.1 为什么需要 KV Cache
自回归生成的冗余计算。 回顾注意力机制的计算过程:对于输入序列中的每个 token,模型通过线性投影分别生成 Query(Q)、Key(K)、Value(V)三个向量,然后计算注意力分数
假设模型正在处理提示词 "Time flies",此时注意力计算涉及两个 token 的 K、V 向量。当模型生成了新 token "fast" 后,下一轮输入变为 "Time flies fast",需要重新计算三个 token 的所有 K、V 向量。但仔细观察可以发现:"Time" 和 "flies" 的 K、V 向量与上一轮完全相同——它们仅取决于对应位置的输入 embedding 和线性投影权重,与序列中后续 token 无关。
将这一观察推广到整个生成过程。设生成序列长度为
这意味着生成长度翻倍,计算量将翻四倍。
KV Cache 的核心思想。 既然前
以四步生成过程 "遥→遥→领→先" 为例,展开注意力计算公式:
关键观察:
只需要 ——当前 token 的 Query 向量。历史 token 的 Query 不需要重新计算,因为它们的输出已经在之前的步骤中使用过了。 和 可以复用——它们在之前的步骤中已经计算过,且值不会改变。
因此,KV Cache 的工作流程分为两个阶段:
- Prefill(预填充)阶段:将完整的提示词序列一次性送入模型,计算所有 token 的 K、V 向量并存入缓存。
- Decode(解码)阶段:每步仅将新生成的单个 token 送入模型,计算其 Q、K、V 向量,其中 K、V 追加到缓存,Q 与缓存中的所有 K 计算注意力分数。
6.1.2 基础实现

图 6-2:注意力权重矩阵的可视化。因果掩码确保每个位置只能关注自身及之前的 token,形成下三角矩阵模式。
下面在 MultiHeadAttention 类中实现 KV Cache。核心改动只有三处:注册缓存缓冲区、在前向传播中维护缓存、提供缓存重置方法。
第一步:注册缓存缓冲区。 在 MultiHeadAttention 的构造函数中,使用 register_buffer 注册两个缓冲区 cache_k 和 cache_v,并维护一个位置指针 ptr_current_pos 追踪当前缓存中已填充的 token 数:
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout,
num_heads, qkv_bias=False):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out)
self.dropout = nn.Dropout(dropout)
self.register_buffer(
"mask",
torch.triu(torch.ones(context_length, context_length), diagonal=1),
persistent=False
)
# KV Cache 缓冲区
self.register_buffer("cache_k", None, persistent=False)
self.register_buffer("cache_v", None, persistent=False)
self.ptr_current_pos = 0 # 位置指针,追踪已缓存的 token 数使用 register_buffer 而非普通属性,确保缓存会随模型一起迁移到目标设备(CPU/GPU),但不会被优化器视为可训练参数。设置 persistent=False 表示缓存不需要被保存到 state_dict 中——它是纯推理时的临时状态。
第二步:在前向传播中维护缓存。 扩展 forward 方法,增加 use_cache 参数。当启用缓存时,仅为新输入的 token 计算 K、V,然后将其追加到缓存中;因果掩码也需要相应调整——Query 的行索引不再从 0 开始,而是从 ptr_current_pos 开始:
class MultiHeadAttention(nn.Module):
# ... __init__ 同前 ...
def forward(self, x, use_cache=False):
b, num_tokens, d_in = x.shape
# 始终只为新输入的 token 计算 Q、K、V
keys_new = self.W_key(x)
values_new = self.W_value(x)
queries = self.W_query(x)
# 拆分多头: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim)
values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
# 缓存逻辑:追加新 K/V 到缓存,或直接使用新 K/V
if use_cache:
if self.cache_k is None:
self.cache_k = keys_new
self.cache_v = values_new
else:
self.cache_k = torch.cat([self.cache_k, keys_new], dim=1)
self.cache_v = torch.cat([self.cache_v, values_new], dim=1)
keys, values = self.cache_k, self.cache_v
else:
keys, values = keys_new, values_new
# 转置: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)
# 计算注意力分数
attn_scores = queries @ keys.transpose(2, 3)
# 因果掩码——启用缓存时需要偏移行索引
num_tokens_Q = queries.shape[-2]
num_tokens_K = keys.shape[-2]
if use_cache:
mask_bool = self.mask.bool()[
self.ptr_current_pos : self.ptr_current_pos + num_tokens_Q,
:num_tokens_K
]
self.ptr_current_pos += num_tokens_Q
else:
mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]
attn_scores.masked_fill_(mask_bool, -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)
context_vec = (attn_weights @ values).transpose(1, 2)
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec)
return context_vec掩码偏移是最容易出错的部分。在 Prefill 阶段,ptr_current_pos = 0,num_tokens_Q 等于提示词长度,掩码从第 0 行开始截取,与不使用缓存时完全一致。在 Decode 阶段,每次只输入一个新 token,num_tokens_Q = 1,掩码从 ptr_current_pos 行截取一行。由于因果掩码的第
第三步:缓存重置。 不同序列之间必须清空缓存,否则前一个序列的 K、V 会污染下一个序列的注意力计算:
class MultiHeadAttention(nn.Module):
# ... 省略其他方法 ...
def reset_cache(self):
self.cache_k = None
self.cache_v = None
self.ptr_current_pos = 0第四步:在模型层级传播 use_cache。 需要修改 TransformerBlock 和 GPTModel,将 use_cache 参数逐层传递。GPTModel 还需要维护自己的位置计数器,确保位置 embedding 的索引正确递增:
class TransformerBlock(nn.Module):
# __init__ 不变
def forward(self, x, use_cache=False):
shortcut = x
x = self.norm1(x)
x = self.att(x, use_cache=use_cache) # 传递 use_cache
x = self.drop_shortcut(x)
x = x + shortcut
shortcut = x
x = self.norm2(x)
x = self.ff(x)
x = self.drop_shortcut(x)
x = x + shortcut
return x
class GPTModel(nn.Module):
def __init__(self, cfg):
super().__init__()
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
self.drop_emb = nn.Dropout(cfg["drop_rate"])
# 使用 ModuleList 替代 Sequential,以便逐层传递 use_cache
self.trf_blocks = nn.ModuleList(
[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
)
self.current_pos = 0 # 模型级位置计数器
self.final_norm = LayerNorm(cfg["emb_dim"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
def forward(self, in_idx, use_cache=False):
batch_size, seq_len = in_idx.shape
tok_embeds = self.tok_emb(in_idx)
# 位置 embedding:缓存模式下从 current_pos 开始,否则从 0 开始
if use_cache:
pos_ids = torch.arange(
self.current_pos, self.current_pos + seq_len,
device=in_idx.device, dtype=torch.long
)
self.current_pos += seq_len
else:
pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long)
pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
x = tok_embeds + pos_embeds
x = self.drop_emb(x)
for blk in self.trf_blocks:
x = blk(x, use_cache=use_cache)
x = self.final_norm(x)
logits = self.out_head(x)
return logits
def reset_kv_cache(self):
for blk in self.trf_blocks:
blk.att.reset_cache()
self.current_pos = 0注意 nn.Sequential 被替换为 nn.ModuleList——前者的 forward 方法签名固定为 forward(x),无法传递额外参数;后者需要手动循环调用每个模块,但支持任意参数传递。
第五步:使用缓存的生成函数。 完整的生成流程如下:
def generate_with_kv_cache(model, idx, max_new_tokens, context_size=None):
model.eval()
ctx_len = context_size or model.pos_emb.num_embeddings
with torch.no_grad():
# Prefill: 将完整提示词送入模型,初始化缓存
model.reset_kv_cache()
logits = model(idx[:, -ctx_len:], use_cache=True)
for _ in range(max_new_tokens):
# 贪心解码:选择概率最高的 token
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
idx = torch.cat([idx, next_idx], dim=1)
# Decode: 仅将新 token 送入模型
logits = model(next_idx, use_cache=True)
return idx与不使用缓存的版本对比:
def generate_without_cache(model, idx, max_new_tokens, context_size):
model.eval()
for _ in range(max_new_tokens):
with torch.no_grad():
logits = model(idx[:, -context_size:]) # 每步都送入完整序列
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
idx = torch.cat([idx, next_idx], dim=1)
return idx关键差异在于 Decode 阶段的输入:缓存版本每步仅传入 next_idx(一个 token),朴素版本每步传入 idx[:, -context_size:](整个序列)。随着序列增长,两者的计算量差距迅速拉大。
6.1.3 优化:预分配张量与滑动窗口

图 6-3:因果掩码在 Prefill 和 Decode 阶段的不同行为。Prefill 阶段处理完整提示词,Decode 阶段每步仅处理一个新 token 并利用 KV Cache 避免重复计算。
上述基础实现虽然逻辑清晰,但存在两个工程缺陷:
- 反复分配内存。 每次调用
torch.cat追加新 K、V 时,PyTorch 必须分配一块新的连续内存,将旧数据和新数据拷贝过去,然后释放旧内存。随着序列变长,这一过程涉及的数据量线性增长,导致显著的内存碎片和性能下降。 - 内存无上限增长。 缓存大小随生成长度线性增长,对于长序列场景,KV Cache 可能耗尽 GPU 显存。以一个具体的例子估算:batch_size=32、num_heads=32、num_layers=32、head_dim=128、seq_length=2048、float32 类型,KV Cache 需要
字节 64 GB 显存——这足以耗尽大多数单卡的显存。
针对这两个问题,分别采用预分配张量和滑动窗口截断两种优化策略。
优化一:预分配张量。 在模型初始化时,根据最大序列长度一次性分配好 K、V 缓存的完整内存,后续生成过程中只需向对应位置写入数据,无需反复分配和拷贝:
class MultiHeadAttentionOptimized(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout,
num_heads, qkv_bias=False, window_size=None):
super().__init__()
assert d_out % num_heads == 0
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads
self.window_size = window_size or context_length
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out)
self.dropout = nn.Dropout(dropout)
# 预分配缓存(初始化为 None,首次使用时按 batch_size 分配)
self.register_buffer("cache_k", None, persistent=False)
self.register_buffer("cache_v", None, persistent=False)
def forward(self, x, use_cache=False):
b, num_tokens, _ = x.shape
keys_new = self.W_key(x)
values_new = self.W_value(x)
queries = self.W_query(x)
keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim)
values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
# 先转置再操作缓存: (b, num_heads, num_tokens, head_dim)
keys_new = keys_new.transpose(1, 2)
values_new = values_new.transpose(1, 2)
queries = queries.transpose(1, 2)
if use_cache:
# 首次调用或 batch_size 变化时,预分配完整缓存
if self.cache_k is None or self.cache_k.size(0) != b:
self.cache_k = torch.zeros(
b, self.num_heads, self.window_size, self.head_dim,
device=x.device
)
self.cache_v = torch.zeros_like(self.cache_k)
self.ptr_cur = 0
# 如果追加后会溢出窗口,丢弃最早的 token
if self.ptr_cur + num_tokens > self.window_size:
overflow = self.ptr_cur + num_tokens - self.window_size
self.cache_k[:, :, :-overflow, :] = \
self.cache_k[:, :, overflow:, :].clone()
self.cache_v[:, :, :-overflow, :] = \
self.cache_v[:, :, overflow:, :].clone()
self.ptr_cur -= overflow
# 写入新 K/V 到缓存的对应位置(无需重新分配内存)
self.cache_k[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = keys_new
self.cache_v[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = values_new
self.ptr_cur += num_tokens
# 只取已填充部分
keys = self.cache_k[:, :, :self.ptr_cur, :]
values = self.cache_v[:, :, :self.ptr_cur, :]
else:
keys, values = keys_new, values_new
# 注意力计算
attn_scores = queries @ keys.transpose(2, 3)
# 动态构建因果掩码
K = attn_scores.size(-1)
if num_tokens == K:
causal_mask = torch.triu(
torch.ones(num_tokens, K, device=x.device, dtype=torch.bool),
diagonal=1
)
else:
offset = K - num_tokens
row_idx = torch.arange(num_tokens, device=x.device).unsqueeze(1)
col_idx = torch.arange(K, device=x.device).unsqueeze(0)
causal_mask = col_idx > row_idx + offset
attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)
context_vec = (attn_weights @ values).transpose(1, 2)
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec)
return context_vec
def reset_cache(self):
self.cache_k = None
self.cache_v = None与基础版本相比,这一实现的关键变化有三处:
- 缓存形状固定为
(b, num_heads, window_size, head_dim),一次性分配后不再变化。写入新 K/V 时通过切片赋值cache_k[:, :, ptr:ptr+n, :] = keys_new,这是原地操作,不触发内存分配。 - 先转置再存入缓存。基础版本在
(b, num_tokens, num_heads, head_dim)维度上拼接后再转置,优化版本直接以(b, num_heads, num_tokens, head_dim)格式存储,减少一次转置操作。 - 动态因果掩码替代预存储掩码。预分配版本的缓存长度可能因滑动窗口截断而变化,不再适合使用固定的上三角掩码。通过
col_idx > row_idx + offset动态计算掩码,其中offset = K - num_tokens是缓存中已有的历史 token 数量。
优化二:滑动窗口截断。 当生成长度超过 window_size 时,缓存中最早的 token 会被丢弃。上面代码中的溢出处理逻辑实现了这一策略:
if self.ptr_cur + num_tokens > self.window_size:
overflow = self.ptr_cur + num_tokens - self.window_size
# 将缓存内容整体左移 overflow 个位置
self.cache_k[:, :, :-overflow, :] = self.cache_k[:, :, overflow:, :].clone()
self.cache_v[:, :, :-overflow, :] = self.cache_v[:, :, overflow:, :].clone()
self.ptr_cur -= overflow滑动窗口将 KV Cache 的显存占用锁定在 window_size 等于模型的上下文长度时,滑动窗口不会丢弃任何信息,退化为纯预分配优化。
6.1.4 性能对比

图 6-4:KV Cache 的两阶段推理流程。Prefill 阶段一次性处理所有输入 token 并填充缓存,Decode 阶段逐个生成新 token 并增量更新缓存。
在 GPT-2 124M 模型上(12 层、12 头、768 维 embedding、上下文长度 1024),以 "Hello, I am" 作为提示词生成 200 个新 token,三种实现的推理速度对比如下(Mac Mini M4 芯片,CPU 推理):
| 实现方式 | 吞吐量 (tokens/sec) | 相对加速 |
|---|---|---|
| 无缓存(朴素实现) | 27 | 1.0x |
KV Cache(基础版,torch.cat 拼接) | 144 | 5.3x |
| KV Cache(优化版,预分配 + 滑动窗口) | 166 | 6.1x |
表 6-1:三种实现的推理吞吐量对比(GPT-2 124M,200 tokens 生成,CPU)。
几点说明:
- 基础版 KV Cache 已经带来 5.3 倍加速,因为注意力层中 K/V 的投影计算从
降到了 。 - 优化版在基础版之上再提升约 15%,主要来自消除
torch.cat的内存分配开销。 - 这些数据基于 CPU 推理。在 GPU 上,由于小模型的计算量不足以充分利用 GPU 并行度,设备通信和 kernel 启动开销可能抵消 KV Cache 的收益。对于更大的模型(如 7B+ 参数),KV Cache 在 GPU 上的加速比通常更为显著。
- 正确性验证:三种实现在相同输入下生成完全相同的 token 序列。这是验证 KV Cache 实现正确性的关键标准——缓存是纯粹的计算优化,不应改变模型的数学行为。
6.1.5 优缺点分析
优势:
- 计算效率:K/V 投影的总计算量从
降至 ,生成越长的序列收益越大。 - 实现简洁:核心逻辑仅涉及缓冲区注册、条件拼接和位置指针维护,对模型架构的侵入性很小。
代价:
- 显存占用线性增长:每生成一个新 token,缓存增加
个浮点数( 为层数, 为头数, 为每头维度,因子 2 对应 K 和 V 两份)。对于大模型和长序列,KV Cache 可能成为显存瓶颈。 - 只能用于推理:训练阶段使用教师强制,所有位置的 K/V 本就需要完整计算,缓存无意义。
- 增加代码复杂度:
use_cache标志需要贯穿整个模型层级(Attention → TransformerBlock → GPTModel → 生成函数),增加了维护负担。
KV Cache 的显存开销催生了一系列后续优化技术,包括 Multi-Query Attention(多个 Q 头共享一组 K/V)、Grouped-Query Attention(折中方案,若干 Q 头为一组共享 K/V)、以及量化 KV Cache(将缓存从 FP16 压缩到 INT8/INT4)。这些技术将在后续章节中讨论。

图 6-5:因果注意力掩码。下三角掩码确保每个位置只能看到之前的 token,是自回归生成中 KV Cache 有效性的基础。
本节小结
本节介绍了 KV Cache 的原理与实现:
- 原理层面:自回归生成中,历史 token 的 K/V 向量不随新 token 的加入而改变,因此可以缓存复用。缓存将 K/V 投影的总计算量从
降至 ,是 LLM 推理加速的基础技术。 - 基础实现:通过
register_buffer注册cache_k/cache_v缓冲区,使用torch.cat逐步追加新 K/V,配合位置指针ptr_current_pos正确偏移因果掩码。 - 优化实现:预分配固定大小的张量消除反复内存分配,滑动窗口截断控制显存上限。优化后在 124M 模型上达到 166 tokens/sec,相比无缓存的 27 tokens/sec 实现 6.1 倍加速。
- 核心权衡:KV Cache 以显存换速度。缓存大小
,对于大模型和长序列可能成为显存主要消耗项,需要配合 MQA/GQA 等架构级优化共同使用。