9.3 CUDA 与 Triton 编程
前两节讨论了 GPU 的硬件架构与执行模型,以及算子融合的核心思想。本节将从"纸上谈兵"走向"亲手实现"——用 CUDA C++ 和 Triton 两种语言编写 GeLU 和 Softmax 算子,在实践中体会算子融合带来的性能飞跃。最后,我们还将简要窥探编译产物 PTX,看看高级代码在 GPU 底层究竟被翻译成了什么。
9.3.1 CUDA 编程基础
CUDA(Compute Unified Device Architecture)是 NVIDIA 提供的并行计算平台,它在 C++ 语言基础上扩展了一套 GPU 编程原语,使开发者可以编写在数千线程上并行执行的 Kernel 函数。
Kernel 与执行配置。 一个 CUDA Kernel 用 __global__ 关键字声明,表示该函数由 CPU 发起调用、在 GPU 上执行。启动 Kernel 时需要通过 <<<num_blocks, block_size>>> 语法指定**网格(Grid)**中的线程块数量和每个线程块内的线程数量。每个线程通过内置变量 blockIdx.x、threadIdx.x 和 blockDim.x 计算出自己的全局索引:
int i = blockIdx.x * blockDim.x + threadIdx.x;这条公式将 Grid-Block-Thread 的层次结构映射到线性的内存地址上,是几乎所有一维 CUDA Kernel 的起手式。
CPU-GPU 异步执行。 理解 GPU 编程必须掌握的一个核心概念是:CPU 和 GPU 是异步工作的。当 CPU 发出一条 Kernel 启动指令后,它不会等待 GPU 执行完毕,而是立即返回继续执行后续代码。GPU 的调度器独立地从指令队列中取出任务并执行。这种异步解耦是高吞吐量的基础——CPU 可以在 GPU 忙于计算时持续准备下一批数据和 Kernel。
这带来一个重要的实践问题:如果你在 Python 中用 time.time() 计时一段 GPU 代码,测到的可能只是 CPU 提交任务的时间,而非 GPU 实际执行的时间。正确的做法是在计时前后调用 torch.cuda.synchronize(),强制 CPU 等待 GPU 完成所有已提交的任务。类似地,print(loss.item()) 这样的语句会触发隐式同步——CPU 必须等待 GPU 把标量值算出来才能打印——如果在训练循环中频繁出现,就会打断 CPU-GPU 的流水线,成为意外的性能杀手。
9.3.2 CUDA 实现 GeLU
GeLU 的 tanh 近似公式为:
先看一个反面教材——用 PyTorch 基础操作手动实现 GeLU:
# 教学示例:展示核心逻辑,省略了部分 import 和辅助函数定义
def manual_gelu(x: torch.Tensor):
return 0.5 * x * (1 + torch.tanh(0.79788456 * (x + 0.044715 * x * x * x)))这行代码看起来简洁优雅,但在 GPU 上执行时,每个算术操作(乘法、加法、tanh)都可能触发一次独立的 Kernel 调用,每次调用都要从全局内存读取中间结果、计算后再写回。对于一个有 manual_gelu 的执行剖面由一连串小型 Kernel(mul、add、tanh)组成,是典型的内存瓶颈。
现在用 CUDA C++ 将整个公式融合到一个 Kernel 中:
#include <cmath>
#include <torch/extension.h>
// 设备端函数:单个元素的 GeLU 计算
__device__ float gelu_forward(float x) {
return 0.5f * x * (1.0f + tanhf(0.79788456f * (x + 0.044715f * x * x * x)));
}
// GPU Kernel:每个线程处理一个元素
__global__ void gelu_kernel(const float* in, float* out, int num_elements) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < num_elements) {
out[i] = gelu_forward(in[i]);
}
}
// 主机端接口:分配输出、配置网格、启动 Kernel
torch::Tensor gelu(torch::Tensor x) {
TORCH_CHECK(x.is_cuda(), "Input must be a CUDA tensor");
TORCH_CHECK(x.is_contiguous(), "Input must be contiguous");
auto y = torch::empty_like(x);
int num_elements = x.numel();
int block_size = 1024;
int num_blocks = (num_elements + block_size - 1) / block_size;
gelu_kernel<<<num_blocks, block_size>>>(
x.data_ptr<float>(), y.data_ptr<float>(), num_elements);
return y;
}逐行拆解核心部分:
__device__函数gelu_forward在 GPU 上执行、仅可由 GPU 端代码调用,封装了 GeLU 的完整数学公式。__global__函数gelu_kernel是从 CPU 启动的 Kernel。每个线程通过全局索引i定位到自己负责的元素,执行边界检查后调用gelu_forward。- 主机端函数
gelu负责分配输出张量、计算网格配置(向上取整确保所有元素都被覆盖),然后启动 Kernel。 - 通过 PyTorch 的
torch.utils.cpp_extension.load_inline,可以在 Python 中动态编译这段 CUDA 代码并直接调用。
性能对比。 在典型的基准测试中(如
9.3.3 Triton 编程模型
CUDA 功能强大,但 C++ 语法和手动的内存管理使得开发门槛较高。Triton 是由 OpenAI 开发的领域特定语言,允许开发者用 Python 语法编写 GPU Kernel,同时由编译器自动处理内存合并(memory coalescing)、共享内存管理和指令调度等底层优化。
Triton 与 CUDA 最根本的区别在于编程抽象的层次:
| 维度 | CUDA | Triton |
|---|---|---|
| 编程单元 | 单个线程 | 线程块(程序实例) |
| 数据操作 | 标量(in[i]) | 向量/块(tl.load/tl.store) |
| 内存合并 | 手动设计访存模式 | 编译器自动优化 |
| 共享内存 | 显式分配与同步 | 编译器自动管理 |
| 语言 | C++ | Python |
在 Triton 中,开发者通过 tl.program_id(axis) 获取当前程序实例(线程块)的编号,通过 tl.arange 创建块内偏移量向量,然后用 tl.load 和 tl.store 以向量化方式一次性读写整块数据。所有计算都作用于这些向量,编译器负责将其映射到高效的底层指令。
9.3.4 Triton 实现 GeLU
以下是完整的 Triton GeLU 实现:
import torch
import triton
import triton.language as tl
@triton.jit
def triton_gelu_kernel(x_ptr, y_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
# 1. 获取当前程序实例的 ID(对应一个线程块)
pid = tl.program_id(axis=0)
# 2. 计算当前块负责的元素偏移量
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# 3. 边界掩码:防止读写超出张量范围
mask = offsets < num_elements
# 4. 向量化加载:一次性读入 BLOCK_SIZE 个元素
x = tl.load(x_ptr + offsets, mask=mask)
# 5. 在寄存器中完成全部 GeLU 计算
a = 0.79788456 * (x + 0.044715 * x * x * x)
exp = tl.exp(2 * a)
tanh = (exp - 1) / (exp + 1)
y = 0.5 * x * (1 + tanh)
# 6. 向量化写回
tl.store(y_ptr + offsets, y, mask=mask)
def triton_gelu(x: torch.Tensor):
assert x.is_cuda and x.is_contiguous()
y = torch.empty_like(x)
num_elements = x.numel()
block_size = 1024
num_blocks = triton.cdiv(num_elements, block_size)
triton_gelu_kernel[(num_blocks,)](x, y, num_elements, BLOCK_SIZE=block_size)
return y对比 CUDA 版本,几个关键差异值得注意:
- 块级思维。 CUDA Kernel 的逻辑是"我是第
i个线程,处理第i个元素";Triton Kernel 的逻辑是"我是第pid个块,处理从block_start开始的BLOCK_SIZE个元素"。后者更接近深度学习中"操作一个张量块"的思维习惯。 - 掩码替代 if。 CUDA 用
if (i < num_elements)做边界检查;Triton 用mask参数在tl.load和tl.store中统一处理,边界外的元素不会被读写。 - tanh 的实现。 这里用了恒等式
,因为 Triton 语言提供 tl.exp但不直接提供tl.tanh。这也展示了 Triton 编程有时需要手动展开数学公式。
性能表现。 Triton 版本与手写 CUDA 版本性能几乎持平,但代码量减少约一半,且完全在 Python 环境中完成编写、调试和调用。Triton 编译器在背后自动完成了内存合并和线程粗化(让每个物理线程实际处理多个元素以隐藏延迟),生成的底层代码质量与手工优化的 CUDA 相当。
9.3.5 Triton 实现 Softmax
GeLU 是逐元素操作,每个元素的计算独立于其他元素,融合起来相对简单。Softmax 则更具挑战性——它包含归约操作(求最大值、求和),一行内的元素之间存在数据依赖:
减去
如果用 PyTorch 基础操作手动实现,每一步(max、减法、exp、sum、除法)都是一次独立的 Kernel 调用,一行数据在全局内存和 SM 之间来回搬运多达 5 次。Triton 的解决方案是将一整行的所有计算融合到一个 Kernel 中:
import torch
import triton
import triton.language as tl
@triton.jit
def triton_softmax_kernel(
x_ptr, y_ptr, x_row_stride, y_row_stride,
num_cols, BLOCK_SIZE: tl.constexpr
):
# 1. 每个程序实例处理一行
row_idx = tl.program_id(0)
# 2. 计算当前行的列偏移量
col_offsets = tl.arange(0, BLOCK_SIZE)
x_ptrs = x_ptr + row_idx * x_row_stride + col_offsets
# 3. 加载一整行,边界外填充 -inf
mask = col_offsets < num_cols
x_row = tl.load(x_ptrs, mask=mask, other=float("-inf"))
# 4. 片上计算:所有中间结果都在寄存器中
row_max = tl.max(x_row, axis=0) # 归约:求最大值
x_row = x_row - row_max # 数值稳定
numerator = tl.exp(x_row) # 指数
denominator = tl.sum(numerator, axis=0) # 归约:求和
y_row = numerator / denominator # 归一化
# 5. 写回一整行结果
y_ptrs = y_ptr + row_idx * y_row_stride + col_offsets
tl.store(y_ptrs, y_row, mask=mask)
def triton_softmax(x: torch.Tensor):
y = torch.empty_like(x)
M, N = x.shape
block_size = triton.next_power_of_2(N) # 向上取整到 2 的幂
triton_softmax_kernel[(M,)](
x, y, x.stride(0), y.stride(0), N, BLOCK_SIZE=block_size
)
return y这段代码的核心设计是**"一行一块"(one row per block)**:网格大小等于行数
- 通过
tl.load将一整行个元素加载到 SM 的高速寄存器中(边界外用 填充,确保不影响 max和exp的结果)。 - 在寄存器中依次完成
max(归约)、减法、exp、sum(归约)、除法——所有中间结果从未写回全局内存。 - 通过
tl.store将最终结果一次性写回。
为什么 BLOCK_SIZE 要取 2 的幂? 这是一个硬件友好的选择。GPU 的内存子系统和归约操作在处理 2 的幂大小的数据块时效率最高。triton.next_power_of_2(N) 会将实际列数向上取整,多出的位置由掩码屏蔽。
性能对比。 非融合的手动 Softmax 实现对一行数据产生约 torch.compile 自动生成的版本相当。
9.3.6 PTX:编译产物检查与性能归因 [选读]
PTX(Parallel Thread Execution)是 NVIDIA 定义的 GPU 中间表示语言,可以理解为 GPU 的"汇编语言"。当我们编写 CUDA 或 Triton 代码后,编译器首先将其翻译为 PTX;GPU 驱动中的 JIT 编译器再将 PTX 进一步编译为特定 GPU 型号可执行的二进制码(SASS)。
为什么要看 PTX? 它是验证编译器优化是否生效的终极手段。当性能未达预期时,检查 PTX 可以回答以下问题:
- 内存合并是否生效? 如果看到
ld.global.v4.f32(一次加载 4 个 float32),说明编译器成功地将相邻线程的内存请求合并为一次宽事务;如果只看到ld.global.f32(逐个加载),则可能存在低效的访存模式。 - 线程粗化是否发生? 分析 Triton GeLU Kernel 生成的 PTX 时,可以观察到单个线程实际处理了 4 到 8 个元素(出现多组
ld.global.v4.f32+ 计算 +st.global.v4.f32序列),而非"一个线程一个元素"的朴素映射。线程粗化让每个线程承担更多工作,更好地分摊线程调度开销和隐藏内存延迟。 - 数学运算如何映射? GeLU 近似公式中的
tanh在 PTX 层面被分解为ex2.approx.f32(以 2 为底的指数近似)、mul.f32、add.f32等基础指令。通过阅读这些指令序列,可以确认编译器没有引入多余的计算步骤。
以下是 Triton GeLU Kernel 生成的 PTX 片段示例(已简化):
// 寄存器声明
.reg .f32 %f<32>;
.reg .b32 %r<8>;
// 获取线程块 ID 和线程 ID
mov.u32 %r1, %ctaid.x; // blockIdx.x
mov.u32 %r2, %tid.x; // threadIdx.x
// 向量化加载:一次读 4 个 float32
ld.global.v4.f32 {%f1, %f2, %f3, %f4}, [%rd5];
// GeLU 计算(以 %f1 为例)
mul.f32 %f5, %f1, %f1; // x * x
mul.f32 %f6, %f5, %f1; // x^3
mul.f32 %f7, %f6, 0f3D372713; // 0.044715 * x^3
add.f32 %f8, %f1, %f7; // x + 0.044715*x^3
mul.f32 %f9, %f8, 0f3F4C422A; // sqrt(2/pi) * (...)
// ... tanh 近似 + 最终乘法 ...
// 向量化写回
st.global.v4.f32 [%rd6], {%f25, %f26, %f27, %f28};从这段 PTX 中可以确认三个关键优化:(1) v4 后缀表明向量化加载/存储生效;(2) 所有中间值存储在 %f 寄存器中,从未溢出到全局内存;(3) 每个线程处理多个元素(线程粗化)。
PTX 的硬件抽象价值。 PTX 还扮演着"向前兼容"的角色。同一份 PTX 代码可以在未来的新 GPU 上被重新编译为针对新硬件优化的 SASS 码,无需修改上层代码。这使得 CUDA 和 Triton 程序具有跨代 GPU 的可移植性。
注意,日常开发中通常不需要阅读 PTX。它更多地作为调试和性能归因的最后手段——当 Nsight 等工具告诉你某个 Kernel 慢,但你看不出为什么时,PTX 是那扇通往真相的窗户。
9.3.7 本节小结
本节从三个层次展示了 GPU 算子开发的实践:
| 层次 | 工具 | 特点 | 适用场景 |
|---|---|---|---|
| 底层 | CUDA C++ | 精细控制,开发成本高 | 极致优化、非标准硬件操作 |
| 中层 | Triton | Python 语法,编译器自动优化 | 自定义算子的首选(如 FlashAttention) |
| 高层 | torch.compile | 零手动编写,自动融合 | 标准操作的加速 |
核心经验可以归纳为三条:
- 算子融合是性能的分水岭。 同样的数学公式,未融合时(
manual_gelu)每个操作独立触发 Kernel,数据在全局内存中反复搬运;融合后(CUDA/Triton 版本)数据只读写全局内存一次,性能差距可达 4~8 倍。 - Triton 在易用性和性能之间取得了出色平衡。 它的块级编程模型比 CUDA 的线程级模型更符合深度学习的思维习惯,编译器自动处理了内存合并、共享内存和线程粗化等底层优化,代码量减少约一半,性能却与手写 CUDA 持平。
- 理解底层原理始终有价值。 随着
torch.compile等自动编译器的成熟,大多数标准操作无需手写 Kernel。但当你需要实现 FlashAttention 这样具有非常规内存访问模式的新算法时,理解 CUDA 执行模型、CPU-GPU 异步机制和 PTX 编译产物,是写出正确且高效代码的基础。