Skip to content

9.4 FlashAttention 与 FlashMLA

自注意力机制是 Transformer 的核心,但其标准实现存在严峻的性能瓶颈——计算和内存复杂度均为 O(N2),其中 N 为序列长度。然而,真正的瓶颈并非来自浮点运算量(FLOPs)本身,而来自内存访问(I/O)。标准实现需要将 N×N 的注意力矩阵 S=QK 完整地写入 GPU 的全局内存(HBM),而 GPU 片上高速缓存(SRAM)的容量远不足以容纳它。GPU 计算单元不得不反复等待 HBM 的读写操作完成,导致注意力计算成为一个典型的**内存带宽受限(memory-bound)**操作。

FlashAttention 系列算法正是针对这一瓶颈,从 IO-aware 的角度重新设计注意力计算流程,将其从内存受限转变为计算受限,实现了精确注意力的大幅加速与内存节省。

9.4.1 GPU 内存层级与 IO 复杂度分析

要理解 FlashAttention 的设计动机,首先需要了解 GPU 的内存层级结构。

GPU 内存层级是一个金字塔模型:

层级位置容量(典型值)延迟(时钟周期)带宽
寄存器(Registers)SM 内部~256 KB/SM~1最高
共享内存/L1(SRAM)SM 内部48–228 KB/SM~20–30~19 TB/s
L2 缓存芯片上,SM 外部数 MB~200–300~12 TB/s
全局内存(HBM)芯片外部40–80 GB~400+1.5–3.35 TB/s

关键矛盾在于:GPU 的计算能力增长速度远超内存带宽增长速度。以 A100 为例,其 Tensor Core 峰值算力约为 312 TFLOPS(FP16),而 HBM 带宽仅为 2.0 TB/s。这意味着每从 HBM 读取 1 字节数据,GPU 可以完成约 156 次浮点运算。如果一个操作的算术强度(FLOPs/字节)低于这个比值,计算单元就会空转等待数据,形成内存墙。

标准注意力的 IO 复杂度分析:

标准注意力的计算分为以下步骤(每步对应一个独立的 CUDA kernel):

  1. 计算 S=QK:从 HBM 读取 Q,K(各 N×d),写入 SN×N)到 HBM
  2. 计算 P=softmax(S):从 HBM 读取 S,写入 P 到 HBM
  3. 计算 O=PV:从 HBM 读取 P,V,写入 O 到 HBM

总的 HBM 访问量为:

Θ(Nd+N2)

其中 N2 项来自中间矩阵 SP 的读写。当 Nd 时(长序列场景),N2 项占据主导,成为严重的 IO 瓶颈。

FlashAttention 的 IO 复杂度:

FlashAttention 通过分块和算子融合,完全避免了 N×N 矩阵的 HBM 读写。其 HBM 访问量为:

Θ(N2d2M)

其中 M 为 SRAM 的大小。由于 d 通常较小(64 或 128),且 M 在数十 KB 量级,这一复杂度远低于标准实现的 Θ(Nd+N2)

9.4.2 在线 Softmax:分块计算的数学基础

FlashAttention 的核心挑战在于:Softmax 是一个全局依赖操作。标准的数值稳定 Softmax 公式为:

yi=exp(xim)j=1Nexp(xjm),m=maxjxj

计算任何一个 yi 都需要知道全局最大值 m 和全局归一化分母 =jexp(xjm),这与分块计算的局部性原则直接矛盾。

在线 Softmax 算法通过维护和增量更新两个统计量来解决这一问题。

设输入向量 x 被分为 T 个块 x(1),x(2),,x(T)。在处理完前 k 个块后,我们维护:

  • m(k):前 k 个块的全局最大值
  • (k):以 m(k) 为基准的全局指数和

初始化:

m(0)=,(0)=0

递推更新(处理第 k+1 个块):

  1. 计算当前块的局部统计量:
mlocal=max(x(k+1)),local=jexp(xj(k+1)mlocal)
  1. 更新全局最大值:
m(k+1)=max(m(k),mlocal)
  1. 伸缩更新全局指数和:
(k+1)=(k)exp(m(k)m(k+1))+localexp(mlocalm(k+1))

正确性证明:

该递推的正确性基于指数函数的性质 exp(ac)=exp(ab)exp(bc)。设处理完 k 个块后,(k) 满足不变式:

(k)=jblocks 1..kexp(xjm(k))

当引入新块并更新 m(k+1) 后:

(k)exp(m(k)m(k+1))=j1..kexp(xjm(k)+m(k)m(k+1))=j1..kexp(xjm(k+1))

类似地,localexp(mlocalm(k+1)) 将当前块的指数和统一到新基准 m(k+1) 下。两者相加即得到前 k+1 个块在新基准下的正确指数和,不变式成立。

从 Softmax 到注意力输出的在线更新:

在 FlashAttention 中,不仅需要在线计算 Softmax 的分母,还需要同步更新注意力输出 O。设处理完第 k 个 KV 块后,累积输出为 O(k)(未归一化),则引入第 k+1 个块时:

O(k+1)=exp(m(k)m(k+1))O(k)校正旧输出+exp(S(k+1)m(k+1))V(k+1)当前块贡献

最终输出为 Ofinal=O(T)/(T)

9.4.3 FlashAttention 前向与反向传播算法

前向传播算法:

FlashAttention 的前向传播将 Q,K,V 沿序列长度维度分块,在一个融合的 CUDA kernel 中完成全部计算。算法流程如下:

输入: Q,K,VRN×d,存储在 HBM 中;块大小 Br(Q 的行块大小)和 Bc(KV 的列块大小)。

输出: ORN×d,写入 HBM。

// 每个线程块负责计算 O 的一个行块 O_i
初始化: O_i = 0,  m_i = -∞,  ℓ_i = 0

for j = 1 to ⌈N/B_c⌉:                    // 遍历 KV 块
    从 HBM 加载 K_j, V_j 到 SRAM           // 每块大小 B_c × d
    __syncthreads()

    // 步骤 1: 计算局部注意力得分
    S_ij = Q_i · K_j^T                     // B_r × B_c,在 SRAM/寄存器中完成

    // 步骤 2: 在线 Softmax——局部统计量
    m_ij = rowmax(S_ij)
    P_ij = exp(S_ij - m_ij)
    ℓ_ij = rowsum(P_ij)

    // 步骤 3: 在线 Softmax——全局更新
    m_new = max(m_i, m_ij)
    ℓ_new = ℓ_i · exp(m_i - m_new) + ℓ_ij · exp(m_ij - m_new)

    // 步骤 4: 更新输出(含伸缩校正)
    O_i = (ℓ_i / ℓ_new) · exp(m_i - m_new) · O_i
        + (1 / ℓ_new) · exp(m_ij - m_new) · P_ij · V_j

    m_i = m_new,  ℓ_i = ℓ_new
    __syncthreads()

将 O_i 写回 HBM
存储 m_i, ℓ_i 到 HBM(供反向传播使用)

整个过程中,N×N 的注意力矩阵从未被写入 HBM——每个 Br×Bc 的小块 Sij 在 SRAM 中计算、使用、然后丢弃。

反向传播与重计算策略:

标准反向传播需要存储注意力矩阵 P 以计算梯度,这会重新引入 O(N2) 的内存开销。FlashAttention 采用**重计算(recomputation)**策略来规避这一问题:

  • 前向传播时:仅存储最终输出 O 和 Softmax 统计量 (mi,i),丢弃所有中间注意力矩阵块。
  • 反向传播时:从 HBM 重新加载 Q,K,V 的对应块到 SRAM,即时重新计算所需的注意力矩阵块 SijPij,然后在 SRAM 中完成梯度计算。

这是一个经典的以计算换内存的权衡。由于注意力计算在 FlashAttention 中已经从内存受限转变为计算受限,重计算引入的额外 FLOPs 几乎不会影响墙钟时间——GPU 的计算单元本来就没有被充分利用。实测表明,反向传播的重计算仅增加约 25%–33% 的总 FLOPs,但避免了 O(N2) 的内存开销,净效果是正面的。

9.4.4 FlashAttention 1/2/3 演进

FlashAttention 经历了三代迭代,每一代都在前代基础上挖掘更深层的硬件潜力。

特性FlashAttention-1 (2022)FlashAttention-2 (2023)FlashAttention-3 (2024)
目标硬件A100 (Ampere)A100 (Ampere)H100 (Hopper)
核心创新IO-aware 分块 + 在线 Softmax减少非矩阵乘 FLOPs、优化并行异步流水线、FP8 支持、低精度
循环结构外层遍历 KV,内层遍历 Q外层遍历 Q,内层遍历 KV继承 FA2 + warp 特化
并行维度batch × headsbatch × heads × seq_len(Q 维度)batch × heads × seq_len
非 GEMM 操作优化未特别优化将 rescale 等操作移至寄存器WGMMA 异步执行 + softmax 重叠
Tensor Core 利用率~30–50% MMA~50–73% MMA~75%+ WGMMA
典型加速(vs 标准注意力)2–4x在 FA1 基础上再快 ~2x在 FA2 基础上再快 1.5–2x
FP8 支持是(非一致性量化)
因果掩码优化基础跳过精细的块级跳过块级跳过 + 异步掩码

FlashAttention-1 的核心贡献:

  • 首次将 IO-aware 的分块算法引入注意力计算,通过在线 Softmax 实现精确的分块注意力
  • 将内存占用从 O(N2) 降至 O(N)
  • 将注意力从 memory-bound 转变为更接近 compute-bound
  • 反向传播采用重计算策略,避免存储注意力矩阵

FlashAttention-2 的关键改进:

  1. 减少非矩阵乘 FLOPs。 FA1 中有大量的 rescale(伸缩校正)操作在共享内存中完成,无法利用 Tensor Core。FA2 重新组织了计算流程,将在线 Softmax 的统计量更新和 rescale 推迟到最后执行,并尽可能在寄存器中完成,减少了约 50% 的非 GEMM FLOPs。

  2. 循环结构反转与并行性提升。 FA1 的外层循环遍历 KV 块,内层循环遍历 Q 块。这意味着每处理一个新的 KV 块都需要更新所有 Q 块的输出,造成频繁的共享内存读写和同步。FA2 将循环反转为外层遍历 Q 块、内层遍历 KV 块。每个线程块独立负责一个 Q 块的完整计算,输出累积在寄存器中,无需线程块间通信。这一改变还使得可以在序列长度维度上增加并行度——不同的线程块处理不同的 Q 块,而非不同的 head。

  3. warp 级别分工。 FA2 在一个线程块内将 warp 分为两组:一组负责 GEMM 计算(QKPV),另一组负责 Softmax 等非 GEMM 操作。两组 warp 通过共享内存交换数据,实现了计算的流水线化。

FlashAttention-3 对 Hopper 架构的深度适配:

  1. WGMMA 指令与异步执行。 H100 引入了 Warp Group Matrix Multiply-Accumulate(WGMMA)指令,可以直接从共享内存发射矩阵乘法到 Tensor Core,无需先加载到寄存器。FA3 利用 WGMMA 实现了 GEMM 与 Softmax 的真正异步流水线:当 Tensor Core 执行矩阵乘法时,CUDA Core 同时执行 Softmax 的指数运算和归约操作。

  2. TMA(Tensor Memory Accelerator)硬件单元。 Hopper 架构提供了专用的 TMA 单元来异步搬运数据,FA3 利用 TMA 实现了 HBM→SRAM 数据加载与计算的完全重叠(prefetch 下一块数据的同时计算当前块)。

  3. FP8 低精度支持。 FA3 支持 FP8(E4M3/E5M2)精度的注意力计算,通过非一致性量化(incoherent processing)技术缓解 FP8 的精度损失——对 Q 和 K 分别使用随机正交变换,使得量化误差更加均匀分布。在 FP8 模式下,FA3 可达到接近 1.2 PFLOPS 的峰值吞吐。

  4. 块级稀疏与因果掩码。 对于因果掩码(causal mask),FA3 在块级别判断哪些 QK 块对完全被掩码覆盖,直接跳过计算。结合异步掩码应用,进一步减少了无效计算。

9.4.5 Triton 实现 FlashAttention

CUDA 实现的 FlashAttention 虽然性能极高,但代码复杂度也极高(FA2 的 CUDA 代码超过数千行)。OpenAI 的 Triton 语言提供了一种更高层次的 GPU 编程抽象,使得 FlashAttention 的核心逻辑可以用数十行 Python 风格的代码表达,同时保持接近手写 CUDA 的性能。

前向传播的 Triton 实现要点:

python
# 教学示例:展示核心逻辑,省略了部分 import 和辅助函数定义
@triton.jit
def flash_attn_fwd_kernel(
    Q, K, V, O,          # 指向 HBM 中矩阵的指针
    Lse,                  # log-sum-exp 统计量(供反向传播使用)
    stride_qb, stride_qh, stride_qm, stride_qk,  # Q 的各维度步长
    # ... K, V, O 的步长类似 ...
    N_CTX,                # 序列长度
    BLOCK_M: tl.constexpr,  # Q 块大小(编译期常量)
    BLOCK_N: tl.constexpr,  # KV 块大小
    BLOCK_D: tl.constexpr,  # head 维度
):
    # 确定当前程序实例负责的 Q 块索引
    start_m = tl.program_id(0) * BLOCK_M
    off_hz = tl.program_id(1)  # batch × head 的联合索引

    # 初始化累加器
    m_i = tl.full([BLOCK_M], float('-inf'), dtype=tl.float32)
    l_i = tl.full([BLOCK_M], 0.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)

    # 加载 Q 块(整个循环中常驻 SRAM/寄存器)
    q = tl.load(Q_block_ptr)  # [BLOCK_M, BLOCK_D]

    # 内层循环:遍历 KV 块
    for start_n in range(0, N_CTX, BLOCK_N):
        k = tl.load(K_block_ptr)  # [BLOCK_N, BLOCK_D]
        v = tl.load(V_block_ptr)  # [BLOCK_N, BLOCK_D]

        # 计算 S = Q · K^T
        s = tl.dot(q, tl.trans(k))  # [BLOCK_M, BLOCK_N]

        # 在线 Softmax
        m_ij = tl.max(s, axis=1)              # 当前块行最大值
        m_new = tl.maximum(m_i, m_ij)         # 全局最大值更新
        alpha = tl.exp(m_i - m_new)           # 旧输出的伸缩因子
        p = tl.exp(s - m_new[:, None])        # 当前块的 exp 值
        l_new = l_i * alpha + tl.sum(p, axis=1)

        # 更新累加输出
        acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)

        m_i = m_new
        l_i = l_new
        # 推进 KV 块指针...

    # 最终归一化
    acc = acc / l_i[:, None]
    # 存储 log-sum-exp 供反向传播
    lse = m_i + tl.log(l_i)
    tl.store(Lse_ptr, lse)
    tl.store(O_block_ptr, acc.to(O.dtype.element_ty))

Triton 的关键抽象包括:

  • tl.program_id:类似 CUDA 的 blockIdx,确定当前线程块处理的数据分片
  • tl.load / tl.store:显式的 HBM↔SRAM 数据搬运,支持块指针(block pointer)自动处理步长和边界
  • tl.dot:映射到 Tensor Core 的矩阵乘法指令
  • 编译期常量 tl.constexpr:块大小在编译时确定,编译器可据此优化寄存器分配和共享内存布局

反向传播的 Triton 实现:

反向传播需要计算 dQ,dK,dV 三个梯度。其核心挑战在于 dQ 的累加方向与 dK,dV 的累加方向不同,导致需要两个独立的 kernel(或使用原子操作)。

Triton 反向传播的关键步骤:

  1. 预计算 DiDi=rowsum(dOiOi),这是 Softmax 反向传播的一个辅助量
  2. Kernel 1(计算 dVdK:外层遍历 KV 块,内层遍历 Q 块。对每个 QK 块对,重计算 SijPij(使用前向传播保存的 LSE),然后累加 dVj+=PijdOidKj+=dSijQi
  3. Kernel 2(计算 dQ:外层遍历 Q 块,内层遍历 KV 块。同样重计算 Pij,然后累加 dQi+=dSijKj

其中 Softmax 梯度为:

dSij=Pij(dPijDi)

Triton 实现的 FlashAttention 在 A100 上可达到手写 CUDA 实现约 80%–90% 的性能,但代码量仅为后者的 1/10 左右,极大地降低了开发和维护成本。

9.4.6 FlashMLA:DeepSeek-V2 的 Kernel 级优化

FlashMLA 是 DeepSeek 团队为其 DeepSeek-V2/V3 模型中的Multi-head Latent Attention (MLA) 机制开发的高性能推理 kernel。MLA 是对标准 Multi-Head Attention 的一种变体,通过将 KV 缓存压缩到低秩**潜在空间(latent space)**来大幅减少 KV Cache 的内存占用。FlashMLA 则是这一架构在 Hopper GPU 上的高效实现。

MLA 的架构特点:

标准 MHA 中,每个注意力头独立维护各自的 Key 和 Value 缓存,内存占用为 O(nhNdh),其中 nh 为头数、dh 为头维度。MLA 的核心思想是将 KV 对压缩为低秩潜在向量:

ctKV=WDKVxtRdc

其中 dcnhdh(典型设置中 dc=512nhdh=16384)。在推理时,KV 缓存仅需存储 ctKV,然后在计算注意力时通过上投影矩阵恢复 Key 和 Value:

Kt=WUKctKV,Vt=WUVctKV

这使得 KV 缓存的内存占用降低了一个数量级以上。

FlashMLA 的 Kernel 设计挑战与方案:

MLA 的推理场景与标准注意力有本质不同,带来了独特的 kernel 设计挑战:

  1. 上投影融合。 朴素实现需要先将压缩的 KV 缓存通过矩阵乘法上投影为完整的 K 和 V,再执行注意力计算。这会引入大量额外的 HBM 读写。FlashMLA 将上投影操作融合到注意力 kernel 内部——在 SRAM 中加载压缩的 cKV 块后,先在片上完成上投影得到 K 和 V,然后直接进行注意力计算,中间不经过 HBM。

  2. 分页 KV 缓存(Paged KV Cache)。 在实际推理系统中,不同请求的序列长度各异,KV 缓存需要动态管理。FlashMLA 原生支持分页存储——KV 缓存被分为固定大小的页(page),通过页表(page table)索引。kernel 内部根据页表动态定位每个 KV 块的物理地址,实现了对碎片化内存的高效访问。

  3. 变长序列批处理。 一个推理 batch 中各请求的序列长度通常不同。FlashMLA 使用 tile_scheduler 在 kernel 启动前预计算每个线程块的工作分配,将不等长的序列映射为等大的 tile,确保 GPU 的占用率和负载均衡。

FlashMLA 对 Hopper 架构的利用:

FlashMLA 深度利用了 H100 的硬件特性:

  • TMA 异步数据搬运。 使用 Tensor Memory Accelerator 执行 HBM→共享内存的异步 copy,与计算完全重叠。FlashMLA 采用多级流水线(multi-stage pipeline),在计算当前 KV 块的同时预取下一个块。

  • WGMMA 矩阵乘法。 所有的矩阵乘法(上投影和注意力得分计算)均通过 WGMMA 指令发射到 Tensor Core,直接从共享内存读取操作数,避免了寄存器中转的开销。

  • 共享内存的精细管理。 FlashMLA 将共享内存划分为多个区域:KV 缓存页的缓冲区、Q 块的常驻区、以及中间结果的暂存区。通过 __syncthreads() 和异步栅栏(async barrier)精确控制生产者-消费者之间的同步。

  • Warp 特化(Warp Specialization)。 线程块内的 warp 被分为不同角色:部分 warp 负责 TMA 数据搬运(生产者),部分 warp 负责 WGMMA 计算(消费者)。两组 warp 通过异步栅栏协调,形成高效的流水线。

性能特征:

FlashMLA 针对推理场景(decode 阶段)进行了重点优化,此时每个请求的查询长度为 1,但 KV 缓存长度可达数万 token。这使得注意力计算成为典型的 memory-bound 操作(算术强度很低)。FlashMLA 的优化目标是最大化 HBM 带宽利用率:

  • 在 H800 GPU 上,FlashMLA 可达到接近 3000 GB/s 的有效 HBM 带宽利用率(理论峰值约 3.35 TB/s)
  • 相比朴素实现,端到端推理吞吐提升可达 2–4 倍
  • KV 缓存的内存压缩比约为 13:1(相比标准 MHA)

9.4.7 小结

FlashAttention 系列和 FlashMLA 展示了IO-aware 算法设计的核心思想:在算法正确性不变的前提下,通过重组计算顺序来适配硬件的内存层级结构,将瓶颈从内存带宽转移到计算能力上。

几个关键的设计原则贯穿始终:

  1. 分块(Tiling)是基石。 将大问题分解为能放入 SRAM 的小块,最大化数据局部性和重用。
  2. 在线算法解锁分块。 在线 Softmax 使得看似需要全局信息的操作也能以流式方式完成,这是分块注意力计算的数学基础。
  3. 算子融合消灭中间读写。 将多个操作合并为一个 kernel,所有中间结果在 SRAM/寄存器中流转,不经过 HBM。
  4. 以计算换内存。 重计算策略在计算资源充裕而内存带宽紧张的场景下,是一个正收益的权衡。
  5. 硬件感知的持续演进。 从 FA1 到 FA3 再到 FlashMLA,每一代都深入利用目标硬件的新特性(Tensor Core→WGMMA→TMA→分页内存),追求理论峰值性能。

这些原则不仅适用于注意力计算,也为其他内存密集型算子的优化提供了通用的方法论。