9.4 FlashAttention 与 FlashMLA
自注意力机制是 Transformer 的核心,但其标准实现存在严峻的性能瓶颈——计算和内存复杂度均为
FlashAttention 系列算法正是针对这一瓶颈,从 IO-aware 的角度重新设计注意力计算流程,将其从内存受限转变为计算受限,实现了精确注意力的大幅加速与内存节省。
9.4.1 GPU 内存层级与 IO 复杂度分析
要理解 FlashAttention 的设计动机,首先需要了解 GPU 的内存层级结构。
GPU 内存层级是一个金字塔模型:
| 层级 | 位置 | 容量(典型值) | 延迟(时钟周期) | 带宽 |
|---|---|---|---|---|
| 寄存器(Registers) | SM 内部 | ~256 KB/SM | ~1 | 最高 |
| 共享内存/L1(SRAM) | SM 内部 | 48–228 KB/SM | ~20–30 | ~19 TB/s |
| L2 缓存 | 芯片上,SM 外部 | 数 MB | ~200–300 | ~12 TB/s |
| 全局内存(HBM) | 芯片外部 | 40–80 GB | ~400+ | 1.5–3.35 TB/s |
关键矛盾在于:GPU 的计算能力增长速度远超内存带宽增长速度。以 A100 为例,其 Tensor Core 峰值算力约为 312 TFLOPS(FP16),而 HBM 带宽仅为 2.0 TB/s。这意味着每从 HBM 读取 1 字节数据,GPU 可以完成约 156 次浮点运算。如果一个操作的算术强度(FLOPs/字节)低于这个比值,计算单元就会空转等待数据,形成内存墙。
标准注意力的 IO 复杂度分析:
标准注意力的计算分为以下步骤(每步对应一个独立的 CUDA kernel):
- 计算
:从 HBM 读取 (各 ),写入 ( )到 HBM - 计算
:从 HBM 读取 ,写入 到 HBM - 计算
:从 HBM 读取 ,写入 到 HBM
总的 HBM 访问量为:
其中
FlashAttention 的 IO 复杂度:
FlashAttention 通过分块和算子融合,完全避免了
其中
9.4.2 在线 Softmax:分块计算的数学基础
FlashAttention 的核心挑战在于:Softmax 是一个全局依赖操作。标准的数值稳定 Softmax 公式为:
计算任何一个
在线 Softmax 算法通过维护和增量更新两个统计量来解决这一问题。
设输入向量
:前 个块的全局最大值 :以 为基准的全局指数和
初始化:
递推更新(处理第
- 计算当前块的局部统计量:
- 更新全局最大值:
- 伸缩更新全局指数和:
正确性证明:
该递推的正确性基于指数函数的性质
当引入新块并更新
类似地,
从 Softmax 到注意力输出的在线更新:
在 FlashAttention 中,不仅需要在线计算 Softmax 的分母,还需要同步更新注意力输出
最终输出为
9.4.3 FlashAttention 前向与反向传播算法
前向传播算法:
FlashAttention 的前向传播将
输入:
输出:
// 每个线程块负责计算 O 的一个行块 O_i
初始化: O_i = 0, m_i = -∞, ℓ_i = 0
for j = 1 to ⌈N/B_c⌉: // 遍历 KV 块
从 HBM 加载 K_j, V_j 到 SRAM // 每块大小 B_c × d
__syncthreads()
// 步骤 1: 计算局部注意力得分
S_ij = Q_i · K_j^T // B_r × B_c,在 SRAM/寄存器中完成
// 步骤 2: 在线 Softmax——局部统计量
m_ij = rowmax(S_ij)
P_ij = exp(S_ij - m_ij)
ℓ_ij = rowsum(P_ij)
// 步骤 3: 在线 Softmax——全局更新
m_new = max(m_i, m_ij)
ℓ_new = ℓ_i · exp(m_i - m_new) + ℓ_ij · exp(m_ij - m_new)
// 步骤 4: 更新输出(含伸缩校正)
O_i = (ℓ_i / ℓ_new) · exp(m_i - m_new) · O_i
+ (1 / ℓ_new) · exp(m_ij - m_new) · P_ij · V_j
m_i = m_new, ℓ_i = ℓ_new
__syncthreads()
将 O_i 写回 HBM
存储 m_i, ℓ_i 到 HBM(供反向传播使用)整个过程中,
反向传播与重计算策略:
标准反向传播需要存储注意力矩阵
- 前向传播时:仅存储最终输出
和 Softmax 统计量 ,丢弃所有中间注意力矩阵块。 - 反向传播时:从 HBM 重新加载
的对应块到 SRAM,即时重新计算所需的注意力矩阵块 和 ,然后在 SRAM 中完成梯度计算。
这是一个经典的以计算换内存的权衡。由于注意力计算在 FlashAttention 中已经从内存受限转变为计算受限,重计算引入的额外 FLOPs 几乎不会影响墙钟时间——GPU 的计算单元本来就没有被充分利用。实测表明,反向传播的重计算仅增加约 25%–33% 的总 FLOPs,但避免了
9.4.4 FlashAttention 1/2/3 演进
FlashAttention 经历了三代迭代,每一代都在前代基础上挖掘更深层的硬件潜力。
| 特性 | FlashAttention-1 (2022) | FlashAttention-2 (2023) | FlashAttention-3 (2024) |
|---|---|---|---|
| 目标硬件 | A100 (Ampere) | A100 (Ampere) | H100 (Hopper) |
| 核心创新 | IO-aware 分块 + 在线 Softmax | 减少非矩阵乘 FLOPs、优化并行 | 异步流水线、FP8 支持、低精度 |
| 循环结构 | 外层遍历 KV,内层遍历 Q | 外层遍历 Q,内层遍历 KV | 继承 FA2 + warp 特化 |
| 并行维度 | batch × heads | batch × heads × seq_len(Q 维度) | batch × heads × seq_len |
| 非 GEMM 操作优化 | 未特别优化 | 将 rescale 等操作移至寄存器 | WGMMA 异步执行 + softmax 重叠 |
| Tensor Core 利用率 | ~30–50% MMA | ~50–73% MMA | ~75%+ WGMMA |
| 典型加速(vs 标准注意力) | 2–4x | 在 FA1 基础上再快 ~2x | 在 FA2 基础上再快 1.5–2x |
| FP8 支持 | 否 | 否 | 是(非一致性量化) |
| 因果掩码优化 | 基础跳过 | 精细的块级跳过 | 块级跳过 + 异步掩码 |
FlashAttention-1 的核心贡献:
- 首次将 IO-aware 的分块算法引入注意力计算,通过在线 Softmax 实现精确的分块注意力
- 将内存占用从
降至 - 将注意力从 memory-bound 转变为更接近 compute-bound
- 反向传播采用重计算策略,避免存储注意力矩阵
FlashAttention-2 的关键改进:
减少非矩阵乘 FLOPs。 FA1 中有大量的 rescale(伸缩校正)操作在共享内存中完成,无法利用 Tensor Core。FA2 重新组织了计算流程,将在线 Softmax 的统计量更新和 rescale 推迟到最后执行,并尽可能在寄存器中完成,减少了约 50% 的非 GEMM FLOPs。
循环结构反转与并行性提升。 FA1 的外层循环遍历 KV 块,内层循环遍历 Q 块。这意味着每处理一个新的 KV 块都需要更新所有 Q 块的输出,造成频繁的共享内存读写和同步。FA2 将循环反转为外层遍历 Q 块、内层遍历 KV 块。每个线程块独立负责一个 Q 块的完整计算,输出累积在寄存器中,无需线程块间通信。这一改变还使得可以在序列长度维度上增加并行度——不同的线程块处理不同的 Q 块,而非不同的 head。
warp 级别分工。 FA2 在一个线程块内将 warp 分为两组:一组负责 GEMM 计算(
和 ),另一组负责 Softmax 等非 GEMM 操作。两组 warp 通过共享内存交换数据,实现了计算的流水线化。
FlashAttention-3 对 Hopper 架构的深度适配:
WGMMA 指令与异步执行。 H100 引入了 Warp Group Matrix Multiply-Accumulate(WGMMA)指令,可以直接从共享内存发射矩阵乘法到 Tensor Core,无需先加载到寄存器。FA3 利用 WGMMA 实现了 GEMM 与 Softmax 的真正异步流水线:当 Tensor Core 执行矩阵乘法时,CUDA Core 同时执行 Softmax 的指数运算和归约操作。
TMA(Tensor Memory Accelerator)硬件单元。 Hopper 架构提供了专用的 TMA 单元来异步搬运数据,FA3 利用 TMA 实现了 HBM→SRAM 数据加载与计算的完全重叠(prefetch 下一块数据的同时计算当前块)。
FP8 低精度支持。 FA3 支持 FP8(E4M3/E5M2)精度的注意力计算,通过非一致性量化(incoherent processing)技术缓解 FP8 的精度损失——对 Q 和 K 分别使用随机正交变换,使得量化误差更加均匀分布。在 FP8 模式下,FA3 可达到接近 1.2 PFLOPS 的峰值吞吐。
块级稀疏与因果掩码。 对于因果掩码(causal mask),FA3 在块级别判断哪些 QK 块对完全被掩码覆盖,直接跳过计算。结合异步掩码应用,进一步减少了无效计算。
9.4.5 Triton 实现 FlashAttention
CUDA 实现的 FlashAttention 虽然性能极高,但代码复杂度也极高(FA2 的 CUDA 代码超过数千行)。OpenAI 的 Triton 语言提供了一种更高层次的 GPU 编程抽象,使得 FlashAttention 的核心逻辑可以用数十行 Python 风格的代码表达,同时保持接近手写 CUDA 的性能。
前向传播的 Triton 实现要点:
# 教学示例:展示核心逻辑,省略了部分 import 和辅助函数定义
@triton.jit
def flash_attn_fwd_kernel(
Q, K, V, O, # 指向 HBM 中矩阵的指针
Lse, # log-sum-exp 统计量(供反向传播使用)
stride_qb, stride_qh, stride_qm, stride_qk, # Q 的各维度步长
# ... K, V, O 的步长类似 ...
N_CTX, # 序列长度
BLOCK_M: tl.constexpr, # Q 块大小(编译期常量)
BLOCK_N: tl.constexpr, # KV 块大小
BLOCK_D: tl.constexpr, # head 维度
):
# 确定当前程序实例负责的 Q 块索引
start_m = tl.program_id(0) * BLOCK_M
off_hz = tl.program_id(1) # batch × head 的联合索引
# 初始化累加器
m_i = tl.full([BLOCK_M], float('-inf'), dtype=tl.float32)
l_i = tl.full([BLOCK_M], 0.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)
# 加载 Q 块(整个循环中常驻 SRAM/寄存器)
q = tl.load(Q_block_ptr) # [BLOCK_M, BLOCK_D]
# 内层循环:遍历 KV 块
for start_n in range(0, N_CTX, BLOCK_N):
k = tl.load(K_block_ptr) # [BLOCK_N, BLOCK_D]
v = tl.load(V_block_ptr) # [BLOCK_N, BLOCK_D]
# 计算 S = Q · K^T
s = tl.dot(q, tl.trans(k)) # [BLOCK_M, BLOCK_N]
# 在线 Softmax
m_ij = tl.max(s, axis=1) # 当前块行最大值
m_new = tl.maximum(m_i, m_ij) # 全局最大值更新
alpha = tl.exp(m_i - m_new) # 旧输出的伸缩因子
p = tl.exp(s - m_new[:, None]) # 当前块的 exp 值
l_new = l_i * alpha + tl.sum(p, axis=1)
# 更新累加输出
acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
m_i = m_new
l_i = l_new
# 推进 KV 块指针...
# 最终归一化
acc = acc / l_i[:, None]
# 存储 log-sum-exp 供反向传播
lse = m_i + tl.log(l_i)
tl.store(Lse_ptr, lse)
tl.store(O_block_ptr, acc.to(O.dtype.element_ty))Triton 的关键抽象包括:
tl.program_id:类似 CUDA 的blockIdx,确定当前线程块处理的数据分片tl.load / tl.store:显式的 HBM↔SRAM 数据搬运,支持块指针(block pointer)自动处理步长和边界tl.dot:映射到 Tensor Core 的矩阵乘法指令- 编译期常量
tl.constexpr:块大小在编译时确定,编译器可据此优化寄存器分配和共享内存布局
反向传播的 Triton 实现:
反向传播需要计算
Triton 反向传播的关键步骤:
- 预计算
: ,这是 Softmax 反向传播的一个辅助量 - Kernel 1(计算
和 ):外层遍历 KV 块,内层遍历 Q 块。对每个 QK 块对,重计算 和 (使用前向传播保存的 LSE),然后累加 和 - Kernel 2(计算
):外层遍历 Q 块,内层遍历 KV 块。同样重计算 ,然后累加
其中 Softmax 梯度为:
Triton 实现的 FlashAttention 在 A100 上可达到手写 CUDA 实现约 80%–90% 的性能,但代码量仅为后者的 1/10 左右,极大地降低了开发和维护成本。
9.4.6 FlashMLA:DeepSeek-V2 的 Kernel 级优化
FlashMLA 是 DeepSeek 团队为其 DeepSeek-V2/V3 模型中的Multi-head Latent Attention (MLA) 机制开发的高性能推理 kernel。MLA 是对标准 Multi-Head Attention 的一种变体,通过将 KV 缓存压缩到低秩**潜在空间(latent space)**来大幅减少 KV Cache 的内存占用。FlashMLA 则是这一架构在 Hopper GPU 上的高效实现。
MLA 的架构特点:
标准 MHA 中,每个注意力头独立维护各自的 Key 和 Value 缓存,内存占用为
其中
这使得 KV 缓存的内存占用降低了一个数量级以上。
FlashMLA 的 Kernel 设计挑战与方案:
MLA 的推理场景与标准注意力有本质不同,带来了独特的 kernel 设计挑战:
上投影融合。 朴素实现需要先将压缩的 KV 缓存通过矩阵乘法上投影为完整的 K 和 V,再执行注意力计算。这会引入大量额外的 HBM 读写。FlashMLA 将上投影操作融合到注意力 kernel 内部——在 SRAM 中加载压缩的
块后,先在片上完成上投影得到 K 和 V,然后直接进行注意力计算,中间不经过 HBM。 分页 KV 缓存(Paged KV Cache)。 在实际推理系统中,不同请求的序列长度各异,KV 缓存需要动态管理。FlashMLA 原生支持分页存储——KV 缓存被分为固定大小的页(page),通过页表(page table)索引。kernel 内部根据页表动态定位每个 KV 块的物理地址,实现了对碎片化内存的高效访问。
变长序列批处理。 一个推理 batch 中各请求的序列长度通常不同。FlashMLA 使用
tile_scheduler在 kernel 启动前预计算每个线程块的工作分配,将不等长的序列映射为等大的 tile,确保 GPU 的占用率和负载均衡。
FlashMLA 对 Hopper 架构的利用:
FlashMLA 深度利用了 H100 的硬件特性:
TMA 异步数据搬运。 使用 Tensor Memory Accelerator 执行 HBM→共享内存的异步 copy,与计算完全重叠。FlashMLA 采用多级流水线(multi-stage pipeline),在计算当前 KV 块的同时预取下一个块。
WGMMA 矩阵乘法。 所有的矩阵乘法(上投影和注意力得分计算)均通过 WGMMA 指令发射到 Tensor Core,直接从共享内存读取操作数,避免了寄存器中转的开销。
共享内存的精细管理。 FlashMLA 将共享内存划分为多个区域:KV 缓存页的缓冲区、Q 块的常驻区、以及中间结果的暂存区。通过
__syncthreads()和异步栅栏(async barrier)精确控制生产者-消费者之间的同步。Warp 特化(Warp Specialization)。 线程块内的 warp 被分为不同角色:部分 warp 负责 TMA 数据搬运(生产者),部分 warp 负责 WGMMA 计算(消费者)。两组 warp 通过异步栅栏协调,形成高效的流水线。
性能特征:
FlashMLA 针对推理场景(decode 阶段)进行了重点优化,此时每个请求的查询长度为 1,但 KV 缓存长度可达数万 token。这使得注意力计算成为典型的 memory-bound 操作(算术强度很低)。FlashMLA 的优化目标是最大化 HBM 带宽利用率:
- 在 H800 GPU 上,FlashMLA 可达到接近 3000 GB/s 的有效 HBM 带宽利用率(理论峰值约 3.35 TB/s)
- 相比朴素实现,端到端推理吞吐提升可达 2–4 倍
- KV 缓存的内存压缩比约为 13:1(相比标准 MHA)
9.4.7 小结
FlashAttention 系列和 FlashMLA 展示了IO-aware 算法设计的核心思想:在算法正确性不变的前提下,通过重组计算顺序来适配硬件的内存层级结构,将瓶颈从内存带宽转移到计算能力上。
几个关键的设计原则贯穿始终:
- 分块(Tiling)是基石。 将大问题分解为能放入 SRAM 的小块,最大化数据局部性和重用。
- 在线算法解锁分块。 在线 Softmax 使得看似需要全局信息的操作也能以流式方式完成,这是分块注意力计算的数学基础。
- 算子融合消灭中间读写。 将多个操作合并为一个 kernel,所有中间结果在 SRAM/寄存器中流转,不经过 HBM。
- 以计算换内存。 重计算策略在计算资源充裕而内存带宽紧张的场景下,是一个正收益的权衡。
- 硬件感知的持续演进。 从 FA1 到 FA3 再到 FlashMLA,每一代都深入利用目标硬件的新特性(Tensor Core→WGMMA→TMA→分页内存),追求理论峰值性能。
这些原则不仅适用于注意力计算,也为其他内存密集型算子的优化提供了通用的方法论。