Skip to content

19.4 投机采样(Speculative Decoding)

大语言模型的自回归推理有一个反直觉的事实:GPU 的计算单元大部分时间不是在算数,而是在等数据。每生成一个 Token,就需要从显存中加载一次全量模型权重,而实际的矩阵乘法运算量极小——这就是推理阶段的内存墙(Memory Wall) 问题。投机采样(Speculative Decoding)正是为打破这堵墙而设计的无损加速技术:它用一个小模型快速"猜"多个 Token,再让大模型一次性"验",把串行的访存瓶颈转化为可并行的计算任务。


19.4.1 核心动机:推理为何受限于内存带宽

要理解投机采样的动机,首先需要区分推理中两个阶段的计算特性。

Prefill 阶段处理用户的完整输入提示,所有 Token 同时可见,可以在序列维度上并行计算——此时 GPU 的算力被充分利用,属于计算受限(Compute-Bound) 场景。

Generation(解码)阶段则完全不同。由于自回归特性,每一步只生成一个 Token,必须等前一个 Token 生成完毕才能开始下一个。每生成一个 Token,GPU 需要将整个模型的权重(对于 70B 模型约 140 GB)从显存加载到计算单元,但实际只做了一次向量-矩阵乘法。用算术强度(Arithmetic Intensity) 来衡量——每传输一个字节所执行的浮点运算次数——解码阶段的算术强度接近 1 FLOP/Byte,远低于 GPU 的平衡点(通常在 100 FLOP/Byte 以上),落入 Roofline 模型的内存受限区域。

Roofline 模型示意:LLM 推理解码阶段落入内存受限区域(绿色),GPU 算力大量闲置

图 19-4:Roofline 模型下的 LLM 推理。解码阶段的算术强度极低,性能完全受限于内存带宽,GPU 的计算峰值无法发挥。

这意味着一个关键的不对称性:生成一个 Token 很慢(受访存限制),但验证一批 Token 很快(可以并行计算,类似 Prefill)。投机采样正是利用了这种不对称性:用一个参数量小、推理速度快的草稿模型(Draft Model) Mq 串行生成 γ 个候选 Token,然后用目标大模型(Target Model) Mp 一次前向传播并行验证这些候选——大模型只需多加载一次权重,就能检查多个位置的预测是否正确。


19.4.2 两阶段算法:草稿生成与并行验证

设目标大模型的条件概率分布为 p(x),草稿小模型的条件概率分布为 q(x)。投机采样的每一轮迭代包含两个阶段。

阶段一:草稿生成。 给定当前已生成的前缀序列,草稿模型 Mq 以标准自回归方式依次生成 γ 个候选 Token:x1,x2,,xγ。由于 Mq 参数量远小于 Mp(例如 0.6B vs 4B,或 1B vs 70B),这一步的延迟很低。

阶段二:并行验证。 将前缀序列与 γ 个候选 Token 拼接后,整体输入目标大模型 Mp。大模型通过一次前向传播即可同时计算出所有 γ 个位置的条件概率分布——这与 Prefill 阶段的并行计算原理完全相同,不会带来额外的访存轮次。

投机采样流程:草稿模型串行生成候选 Token,目标模型一次 Forward 并行验证

图 19-5:投机采样的两阶段流程。草稿模型生成 4 个候选 Token,目标模型一次前向传播验证所有位置,接受前两个、拒绝第三个并从修正分布中重采样。

拿到目标模型的概率分布后,算法从左到右逐个检查候选 Token xi

  • 若接受:保留 xi,继续检查 xi+1
  • 若拒绝:丢弃 xi 及其后所有候选,从修正分布中重采样一个 Token 替换 xi,本轮结束。
  • 若全部接受:额外从大模型在位置 γ+1 的分布中采样一个新 Token,本轮共生成 γ+1 个有效 Token。

下图展示了一个实际运行的示例:大模型验证了 9 次就生成了 37 个 Token——因为每次验证都一次性确认了多个由小模型生成的候选。

逐 Token 接受/拒绝示例:绿色为接受,红色为拒绝并重采样

图 19-6:Token 级别的接受与拒绝。绿色 Token 由草稿模型生成并被大模型接受,红色/蓝色 Token 被拒绝后由大模型重采样替换。


19.4.3 无损生成的数学证明 [必读]

投机采样最精妙之处在于:它在数学上严格保证最终输出的 Token 分布与单独运行大模型 Mp 完全一致——这是一种无损加速,不会降低生成质量。

拒绝采样准则。 对于草稿模型采样得到的候选 Token x,接受概率定义为:

accept(x)=min(1,p(x)q(x))

这意味着:

  • q(x)p(x) 时(大模型认为该 Token 至少和草稿模型一样好),无条件接受
  • q(x)>p(x) 时(草稿模型"过度自信"),以概率 p(x)/q(x) 接受,以概率 1p(x)/q(x) 拒绝。

修正重采样分布。 当 Token 被拒绝时,需要从一个残差分布 p(x) 中重新采样:

p(x)=norm(max(0, p(x)q(x)))=max(0, p(x)q(x))xmax(0, p(x)q(x))

为什么这保证了无损? 我们来严格证明:对于词表中任意 Token x,投机采样输出 x 的总概率恰好等于 p(x)

总概率由两条路径贡献:

路径一(直接接受): 草稿模型采样到 x(概率 q(x)),且被接受(概率 min(1,p(x)/q(x)))。

Paccept(x)=q(x)min(1,p(x)q(x))=min(q(x), p(x))

路径二(拒绝后重采样到 x): 草稿模型采样到某个 xx,被拒绝,然后从 p 中恰好重采样到 x。拒绝事件发生的总概率为:

Preject=xq(x)max(0, 1p(x)q(x))=xmax(0, q(x)p(x))

从修正分布重采样到 x 的概率为 p(x)=max(0,p(x)q(x))/Z,其中归一化常数 Z=xmax(0,p(x)q(x))

注意到一个关键等式:xmax(0,q(x)p(x))=xmax(0,p(x)q(x))=Z(因为两个分布的概率质量之和都为 1,"多出"的部分必须等于"缺少"的部分)。因此:

Presample(x)=Zmax(0, p(x)q(x))Z=max(0, p(x)q(x))

两条路径相加:

P(x)=min(q(x),p(x))+max(0, p(x)q(x))=p(x)

最后一步成立是因为:min(a,b)+max(0,ba)=b 对任意非负 a,b 恒成立。这就完成了无损性的证明。

直觉总结: 总采样概率可以看作两个区域的叠加——p(x)q(x) 重叠的部分(min(p,q),由接受路径贡献)加上 p(x) 超出 q(x) 的残差部分(max(0,pq),由重采样路径贡献),二者恰好拼合成完整的 p(x)

具体数值示例。 为了让证明更加直观,我们用一个只有两个 Token {A,B} 的简化词表来验证。假设:

  • 草稿模型分布:q(A)=0.8,q(B)=0.2
  • 目标模型分布:p(A)=0.3,p(B)=0.7

草稿模型对 A 过度自信(q(A)>p(A)),对 B 信心不足(q(B)<p(B))。

计算采样 A 的总概率:

  • 路径一:草稿采样到 A(概率 0.8),接受概率 min(1,0.3/0.8)=0.375。贡献 = 0.8×0.375=0.3
  • 路径二:草稿采样到 B(概率 0.2),接受概率 min(1,0.7/0.2)=1.0,不会拒绝,所以路径二不会给 A 贡献重采样概率。
  • 总概率 P(A)=0.3=p(A)

计算采样 B 的总概率:

  • 路径一:草稿采样到 B(概率 0.2),接受概率 min(1,0.7/0.2)=1.0。贡献 = 0.2×1.0=0.2
  • 路径二:草稿采样到 A(概率 0.8),拒绝概率 10.375=0.625。修正分布 p(B)=max(0,0.70.2)/0.5=1.0(因为 max(0,p(A)q(A))=0,只有 B 有残差)。贡献 = 0.8×0.625×1.0=0.5
  • 总概率 P(B)=0.2+0.5=0.7=p(B)

结果完美还原了目标分布 p,验证了无损性。


19.4.4 效率分析:接受率与加速比

投机采样的加速效果由两个核心参数决定:前瞻步数 γ平均接受率 α

接受率的定义。 在草稿模型的分布下,每个候选 Token 被大模型接受的期望概率为:

α=Exq[min(1,p(x)q(x))]

α 的取值范围为 [0,1],它衡量了草稿模型与目标模型的分布一致性。当两个模型对下一个 Token 的预测高度一致时,α 接近 1;当两者差异巨大时,α 趋近 0。

每轮期望生成的有效 Token 数。 在一轮投机迭代中(草稿模型生成 γ 个候选),期望生成的有效 Token 数为:

E[#tokens]=1αγ+11α

推导过程。N 为一轮中被接受的候选 Token 数(0Nγ)。假设各位置的接受率独立且相同(均为 α),则:

  • 恰好前 k 个被接受、第 k+1 个被拒绝的概率为 αk(1α)k=0,1,,γ1)。
  • 全部 γ 个候选都被接受的概率为 αγ

无论哪种情况,一轮至少产出一个 Token(要么接受的候选、要么拒绝后重采样的 Token、要么全部接受后的 bonus Token)。因此期望有效 Token 数为:

E[N+1]=k=0γ1(k+1)αk(1α)+(γ+1)αγ=1αγ+11α

最后的化简利用了几何级数求和公式 k=0nαk=(1αn+1)/(1α)

场景αγ期望有效 Token 数含义
理想情况156草稿模型几乎完美,一轮生成 γ+1 个 Token
良好匹配0.752.94典型的同系列大小模型组合
匹配较差0.351.40草稿模型太弱,加速有限
最差情况051退化为标准自回归,还额外浪费了草稿模型的计算

表 19-3:不同接受率下的期望有效 Token 数。

加速比分析。 设大模型一次前向传播的延迟为 Tp,草稿模型为 Tq。标准自回归生成 n 个 Token 需要 nTp。投机采样每轮需要 γTq+Tp(草稿生成 + 一次并行验证),期望生成 E 个 Token。因此加速比约为:

SpeedupETpγTq+Tp=E1+γTq/Tp

TqTp(草稿模型远快于目标模型)且 α 较高时,加速比可以接近 E,实践中通常在 2x-3x 范围内。

一个具体的数值示例。 假设使用 Qwen3-0.6B(Tq=5ms)作为草稿模型,Qwen3-4B(Tp=30ms)作为目标模型,设 γ=5,接受率 α=0.67

  • 标准自回归生成 50 个 Token 需要:50×30ms=1500ms
  • 每轮投机采样耗时:5×5ms+30ms=55ms
  • 每轮期望有效 Token 数:(10.676)/(10.67)2.72
  • 生成 50 个 Token 约需 50/2.7218.4
  • 总耗时:18.4×55ms1012ms
  • 加速比约 1500/10121.48x

如果接受率提升到 α=0.85(选择更匹配的草稿模型),每轮期望有效 Token 数提升到约 4.37,加速比可达 2.4x。这说明草稿模型的选择对加速效果有决定性影响。


19.4.5 贪婪模式 vs 采样模式

一个重要的实践发现:贪婪解码(Greedy Decoding)下的投机采样比随机采样模式更快。

  • 贪婪模式:大模型和草稿模型都取 argmax。只要两者预测的最高概率 Token 一致就接受。对于常见的下一个词,两个模型的 argmax 往往相同,接受率很高。
  • 随机采样模式:即使两个模型的概率分布几乎相同,随机抽样也可能碰巧抽到不同的 Token,导致本不该拒绝的候选被拒绝,接受率显著下降。

实验数据表明,在某些任务上贪婪模式可以带来约 1.1x 的额外加速,而随机采样模式可能反而比标准自回归还慢(约 0.85x),原因是随机性导致大量无谓的拒绝,同时还多付出了草稿模型的计算开销。

直觉上可以这样理解:在贪婪模式下,接受条件简化为"大小模型的 argmax 是否一致"——这是一个确定性判断,不受随机性干扰。而在采样模式下,即使 pq 的分布形状非常接近,两次独立的随机采样也可能抽到不同的 Token,导致拒绝率上升。因此,如果应用场景允许贪婪解码(如代码生成、事实性问答),优先使用贪婪模式配合投机采样可以获得最佳加速效果。


19.4.6 工程实践要点

KV Cache 回滚。 当大模型拒绝了第 i 个候选 Token 时,xi 之后所有候选的 KV Cache 都必须回滚——这在不同推理框架中的实现难度和性能开销差异很大,往往比算法本身更难工程化。vLLM 和 SGLang 等主流推理框架已经内置了投机采样支持,但 KV Cache 的裁剪(Trimming)逻辑仍然是性能调优的重点。

草稿模型的选择。 理想的草稿模型应满足两个条件:(1)推理速度远快于目标模型,通常参数量为目标模型的 1/10 到 1/5;(2)与目标模型的分布尽量一致,通常选择同系列的小规模版本效果最佳(如用 Qwen3-0.6B 配合 Qwen3-4B,或用 Llama-1B 配合 Llama-70B)。

前瞻步数 γ 的调优。 γ 过小会限制单轮产出的 Token 数,γ 过大则后续位置的接受率会显著下降(因为草稿模型的累积误差越来越大)。实践中 γ=46 是常见的选择,具体最优值取决于草稿模型的质量和任务特性。

自草稿(Self-Drafting)变体。 除了使用独立的小模型作为草稿器,近年来还出现了不依赖外部草稿模型的方案:Medusa 为目标模型额外训练多个预测头,每个头独立预测未来第 k 个 Token;EAGLE 系列则利用目标模型自身的隐层特征递归生成草稿候选。这些方法省去了维护独立草稿模型的开销,但需要额外的训练或微调。


19.4.7 完整实现:投机采样代码

下面给出一个自包含的投机采样实现,包含草稿生成、拒绝采样、修正分布重采样的完整逻辑。

python
import torch
import torch.nn.functional as F
from typing import Tuple

def speculative_decode(
    target_model,          # 目标大模型(callable, 输入 token ids -> logits)
    draft_model,           # 草稿小模型(callable, 输入 token ids -> logits)
    prefix: torch.Tensor,  # 初始前缀序列, shape: (seq_len,)
    gamma: int = 5,        # 每轮草稿生成的候选 Token 数
    max_tokens: int = 50,  # 最大生成 Token 数
    temperature: float = 1.0,
) -> Tuple[torch.Tensor, dict]:
    """
    投机采样生成算法。
    保证输出分布与单独使用 target_model 完全一致(无损)。
    """
    generated = prefix.clone()
    stats = {"total_draft": 0, "total_accepted": 0, "num_rounds": 0}

    while generated.shape[0] - prefix.shape[0] < max_tokens:
        # ========== 阶段一:草稿模型串行生成 γ 个候选 ==========
        draft_tokens = []
        draft_probs = []
        current = generated.clone()

        for _ in range(gamma):
            with torch.no_grad():
                logits = draft_model(current.unsqueeze(0))  # (1, seq_len, vocab)
                logits_last = logits[0, -1, :] / temperature
                prob_q = F.softmax(logits_last, dim=-1)

            # 从草稿分布中采样
            token = torch.multinomial(prob_q, num_samples=1)  # (1,)
            draft_tokens.append(token.item())
            draft_probs.append(prob_q)
            current = torch.cat([current, token])

        stats["total_draft"] += gamma

        # ========== 阶段二:目标模型一次前向并行验证 ==========
        verify_input = current.unsqueeze(0)  # prefix + γ 个候选
        with torch.no_grad():
            target_logits = target_model(verify_input)  # (1, seq_len+γ, vocab)

        # ========== 阶段三:逐个拒绝采样 ==========
        n_accepted = 0
        for i in range(gamma):
            pos = generated.shape[0] + i - 1  # 对应目标 logits 的位置
            p_logits = target_logits[0, pos, :] / temperature
            p = F.softmax(p_logits, dim=-1)
            q = draft_probs[i]
            x = draft_tokens[i]

            # 接受概率: min(1, p(x)/q(x))
            p_x = p[x].item()
            q_x = q[x].item()
            accept_prob = min(1.0, p_x / max(q_x, 1e-10))

            if torch.rand(1).item() < accept_prob:
                # 接受该候选 Token
                generated = torch.cat([generated, torch.tensor([x], device=generated.device)])
                n_accepted += 1
            else:
                # 拒绝:从修正分布 p'(x) = norm(max(0, p-q)) 中重采样
                residual = torch.clamp(p - q, min=0)
                residual_sum = residual.sum()
                if residual_sum > 1e-10:
                    p_prime = residual / residual_sum
                else:
                    p_prime = p  # 退化为直接使用目标分布
                new_token = torch.multinomial(p_prime, num_samples=1)
                generated = torch.cat([generated, new_token.to(generated.device)])
                break
        else:
            # 所有 γ 个候选均被接受,从大模型最后一个位置再采样一个
            bonus_logits = target_logits[0, -1, :] / temperature
            bonus_prob = F.softmax(bonus_logits, dim=-1)
            bonus_token = torch.multinomial(bonus_prob, num_samples=1)
            generated = torch.cat([generated, bonus_token.to(generated.device)])
            n_accepted += 1  # 计入 bonus token

        stats["total_accepted"] += n_accepted
        stats["num_rounds"] += 1

    # 截断到目标长度
    generated = generated[: prefix.shape[0] + max_tokens]
    stats["acceptance_rate"] = stats["total_accepted"] / max(stats["total_draft"], 1)
    return generated, stats

代码关键点解读:

  1. 草稿生成循环:逐步调用草稿模型生成 γ 个 Token,同时保存每一步的概率分布 q(x),供后续验证使用。
  2. 目标模型并行验证:将完整序列(前缀 + γ 个候选)一次性输入目标模型,利用因果注意力掩码,一次前向传播获得所有位置的概率分布 p(x)
  3. 拒绝采样:按照 min(1,p(x)/q(x)) 的接受概率逐个检查候选 Token。一旦拒绝,立即从修正分布 p(x)=norm(max(0,pq)) 中重采样并终止本轮。
  4. 全部接受的 bonus:若 γ 个候选全部通过验证,目标模型在最后一个位置的输出可以额外提供一个"免费"Token。
  5. torch.clamp(p - q, min=0) 实现了 max(0,p(x)q(x)) 的向量化计算,避免逐元素循环。

19.4.8 小结

本节介绍了投机采样(Speculative Decoding)的完整理论与实践。回顾要点如下:

要素核心内容
动机自回归推理受限于内存带宽而非计算能力,GPU 大部分时间在等待数据搬运
方法小模型 Mq 串行生成 γ 个草稿 Token,大模型 Mp 一次前向传播并行验证
无损保证拒绝采样 + 修正分布 p(x)=norm(max(0,pq)),数学证明输出分布恒等于 p(x)
效率接受率 α=E[min(1,p/q)] 决定加速比,实践中约 2-3 倍
实践贪婪模式优于随机采样;γ=46 为常见选择;KV Cache 回滚是主要工程难点

投机采样的核心洞察是利用**"生成慢、验证快"** 的不对称性——这一思想不仅限于大小模型配对,也催生了 Medusa、EAGLE 等自草稿变体,以及与 MTP(Multi-Token Prediction)训练头结合的方案。理解了本节的数学原理,读者可以更深入地评估和选择适合自身场景的推理加速策略。