Skip to content

9.3 CUDA 与 Triton 编程

前两节讨论了 GPU 的硬件架构与执行模型,以及算子融合的核心思想。本节将从"纸上谈兵"走向"亲手实现"——用 CUDA C++ 和 Triton 两种语言编写 GeLU 和 Softmax 算子,在实践中体会算子融合带来的性能飞跃。最后,我们还将简要窥探编译产物 PTX,看看高级代码在 GPU 底层究竟被翻译成了什么。


9.3.1 CUDA 编程基础

CUDA(Compute Unified Device Architecture)是 NVIDIA 提供的并行计算平台,它在 C++ 语言基础上扩展了一套 GPU 编程原语,使开发者可以编写在数千线程上并行执行的 Kernel 函数。

Kernel 与执行配置。 一个 CUDA Kernel 用 __global__ 关键字声明,表示该函数由 CPU 发起调用、在 GPU 上执行。启动 Kernel 时需要通过 <<<num_blocks, block_size>>> 语法指定**网格(Grid)**中的线程块数量和每个线程块内的线程数量。每个线程通过内置变量 blockIdx.xthreadIdx.xblockDim.x 计算出自己的全局索引:

int i = blockIdx.x * blockDim.x + threadIdx.x;

这条公式将 Grid-Block-Thread 的层次结构映射到线性的内存地址上,是几乎所有一维 CUDA Kernel 的起手式。

CPU-GPU 异步执行。 理解 GPU 编程必须掌握的一个核心概念是:CPU 和 GPU 是异步工作的。当 CPU 发出一条 Kernel 启动指令后,它不会等待 GPU 执行完毕,而是立即返回继续执行后续代码。GPU 的调度器独立地从指令队列中取出任务并执行。这种异步解耦是高吞吐量的基础——CPU 可以在 GPU 忙于计算时持续准备下一批数据和 Kernel。

这带来一个重要的实践问题:如果你在 Python 中用 time.time() 计时一段 GPU 代码,测到的可能只是 CPU 提交任务的时间,而非 GPU 实际执行的时间。正确的做法是在计时前后调用 torch.cuda.synchronize(),强制 CPU 等待 GPU 完成所有已提交的任务。类似地,print(loss.item()) 这样的语句会触发隐式同步——CPU 必须等待 GPU 把标量值算出来才能打印——如果在训练循环中频繁出现,就会打断 CPU-GPU 的流水线,成为意外的性能杀手。


9.3.2 CUDA 实现 GeLU

GeLU 的 tanh 近似公式为:

GeLU(x)0.5x(1+tanh[2/π(x+0.044715x3)])

先看一个反面教材——用 PyTorch 基础操作手动实现 GeLU:

python
# 教学示例:展示核心逻辑,省略了部分 import 和辅助函数定义
def manual_gelu(x: torch.Tensor):
    return 0.5 * x * (1 + torch.tanh(0.79788456 * (x + 0.044715 * x * x * x)))

这行代码看起来简洁优雅,但在 GPU 上执行时,每个算术操作(乘法、加法、tanh)都可能触发一次独立的 Kernel 调用,每次调用都要从全局内存读取中间结果、计算后再写回。对于一个有 N 个元素的张量,这意味着多达 8 次全局内存的读写往返。性能分析工具清楚地显示:manual_gelu 的执行剖面由一连串小型 Kernel(muladdtanh)组成,是典型的内存瓶颈。

现在用 CUDA C++ 将整个公式融合到一个 Kernel 中:

cpp
#include <cmath>
#include <torch/extension.h>

// 设备端函数:单个元素的 GeLU 计算
__device__ float gelu_forward(float x) {
    return 0.5f * x * (1.0f + tanhf(0.79788456f * (x + 0.044715f * x * x * x)));
}

// GPU Kernel:每个线程处理一个元素
__global__ void gelu_kernel(const float* in, float* out, int num_elements) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < num_elements) {
        out[i] = gelu_forward(in[i]);
    }
}

// 主机端接口:分配输出、配置网格、启动 Kernel
torch::Tensor gelu(torch::Tensor x) {
    TORCH_CHECK(x.is_cuda(), "Input must be a CUDA tensor");
    TORCH_CHECK(x.is_contiguous(), "Input must be contiguous");
    auto y = torch::empty_like(x);
    int num_elements = x.numel();
    int block_size = 1024;
    int num_blocks = (num_elements + block_size - 1) / block_size;
    gelu_kernel<<<num_blocks, block_size>>>(
        x.data_ptr<float>(), y.data_ptr<float>(), num_elements);
    return y;
}

逐行拆解核心部分:

  • __device__ 函数 gelu_forward 在 GPU 上执行、仅可由 GPU 端代码调用,封装了 GeLU 的完整数学公式。
  • __global__ 函数 gelu_kernel 是从 CPU 启动的 Kernel。每个线程通过全局索引 i 定位到自己负责的元素,执行边界检查后调用 gelu_forward
  • 主机端函数 gelu 负责分配输出张量、计算网格配置(向上取整确保所有元素都被覆盖),然后启动 Kernel。
  • 通过 PyTorch 的 torch.utils.cpp_extension.load_inline,可以在 Python 中动态编译这段 CUDA 代码并直接调用。

性能对比。 在典型的基准测试中(如 N=222 个 float32 元素),手动 PyTorch 版本约需 8.1ms,而 CUDA 融合版本仅需约 1.8ms——快了约 4.5 倍。更重要的是,性能分析只看到一个 Kernel,数据仅经历一次从全局内存读取和一次写回,中间所有计算都在寄存器中完成。这正是算子融合的核心收益。


9.3.3 Triton 编程模型

CUDA 功能强大,但 C++ 语法和手动的内存管理使得开发门槛较高。Triton 是由 OpenAI 开发的领域特定语言,允许开发者用 Python 语法编写 GPU Kernel,同时由编译器自动处理内存合并(memory coalescing)、共享内存管理和指令调度等底层优化。

Triton 与 CUDA 最根本的区别在于编程抽象的层次

维度CUDATriton
编程单元单个线程线程块(程序实例)
数据操作标量(in[i]向量/块(tl.load/tl.store
内存合并手动设计访存模式编译器自动优化
共享内存显式分配与同步编译器自动管理
语言C++Python

在 Triton 中,开发者通过 tl.program_id(axis) 获取当前程序实例(线程块)的编号,通过 tl.arange 创建块内偏移量向量,然后用 tl.loadtl.store 以向量化方式一次性读写整块数据。所有计算都作用于这些向量,编译器负责将其映射到高效的底层指令。


9.3.4 Triton 实现 GeLU

以下是完整的 Triton GeLU 实现:

python
import torch
import triton
import triton.language as tl

@triton.jit
def triton_gelu_kernel(x_ptr, y_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
    # 1. 获取当前程序实例的 ID(对应一个线程块)
    pid = tl.program_id(axis=0)

    # 2. 计算当前块负责的元素偏移量
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # 3. 边界掩码:防止读写超出张量范围
    mask = offsets < num_elements

    # 4. 向量化加载:一次性读入 BLOCK_SIZE 个元素
    x = tl.load(x_ptr + offsets, mask=mask)

    # 5. 在寄存器中完成全部 GeLU 计算
    a = 0.79788456 * (x + 0.044715 * x * x * x)
    exp = tl.exp(2 * a)
    tanh = (exp - 1) / (exp + 1)
    y = 0.5 * x * (1 + tanh)

    # 6. 向量化写回
    tl.store(y_ptr + offsets, y, mask=mask)

def triton_gelu(x: torch.Tensor):
    assert x.is_cuda and x.is_contiguous()
    y = torch.empty_like(x)
    num_elements = x.numel()
    block_size = 1024
    num_blocks = triton.cdiv(num_elements, block_size)
    triton_gelu_kernel[(num_blocks,)](x, y, num_elements, BLOCK_SIZE=block_size)
    return y

对比 CUDA 版本,几个关键差异值得注意:

  1. 块级思维。 CUDA Kernel 的逻辑是"我是第 i 个线程,处理第 i 个元素";Triton Kernel 的逻辑是"我是第 pid 个块,处理从 block_start 开始的 BLOCK_SIZE 个元素"。后者更接近深度学习中"操作一个张量块"的思维习惯。
  2. 掩码替代 if。 CUDA 用 if (i < num_elements) 做边界检查;Triton 用 mask 参数在 tl.loadtl.store 中统一处理,边界外的元素不会被读写。
  3. tanh 的实现。 这里用了恒等式 tanh(a)=(e2a1)/(e2a+1),因为 Triton 语言提供 tl.exp 但不直接提供 tl.tanh。这也展示了 Triton 编程有时需要手动展开数学公式。

性能表现。 Triton 版本与手写 CUDA 版本性能几乎持平,但代码量减少约一半,且完全在 Python 环境中完成编写、调试和调用。Triton 编译器在背后自动完成了内存合并和线程粗化(让每个物理线程实际处理多个元素以隐藏延迟),生成的底层代码质量与手工优化的 CUDA 相当。


9.3.5 Triton 实现 Softmax

GeLU 是逐元素操作,每个元素的计算独立于其他元素,融合起来相对简单。Softmax 则更具挑战性——它包含归约操作(求最大值、求和),一行内的元素之间存在数据依赖:

Softmax(z)i=ezimax(z)jezjmax(z)

减去 max(z) 是为了数值稳定性,防止指数运算溢出。

如果用 PyTorch 基础操作手动实现,每一步(max、减法、expsum、除法)都是一次独立的 Kernel 调用,一行数据在全局内存和 SM 之间来回搬运多达 5 次。Triton 的解决方案是将一整行的所有计算融合到一个 Kernel 中:

python
import torch
import triton
import triton.language as tl

@triton.jit
def triton_softmax_kernel(
    x_ptr, y_ptr, x_row_stride, y_row_stride,
    num_cols, BLOCK_SIZE: tl.constexpr
):
    # 1. 每个程序实例处理一行
    row_idx = tl.program_id(0)

    # 2. 计算当前行的列偏移量
    col_offsets = tl.arange(0, BLOCK_SIZE)
    x_ptrs = x_ptr + row_idx * x_row_stride + col_offsets

    # 3. 加载一整行,边界外填充 -inf
    mask = col_offsets < num_cols
    x_row = tl.load(x_ptrs, mask=mask, other=float("-inf"))

    # 4. 片上计算:所有中间结果都在寄存器中
    row_max = tl.max(x_row, axis=0)       # 归约:求最大值
    x_row = x_row - row_max               # 数值稳定
    numerator = tl.exp(x_row)              # 指数
    denominator = tl.sum(numerator, axis=0) # 归约:求和
    y_row = numerator / denominator        # 归一化

    # 5. 写回一整行结果
    y_ptrs = y_ptr + row_idx * y_row_stride + col_offsets
    tl.store(y_ptrs, y_row, mask=mask)

def triton_softmax(x: torch.Tensor):
    y = torch.empty_like(x)
    M, N = x.shape
    block_size = triton.next_power_of_2(N)  # 向上取整到 2 的幂
    triton_softmax_kernel[(M,)](
        x, y, x.stride(0), y.stride(0), N, BLOCK_SIZE=block_size
    )
    return y

这段代码的核心设计是**"一行一块"(one row per block)**:网格大小等于行数 M,每个程序实例独立处理一行。具体流程:

  1. 通过 tl.load 将一整行 N 个元素加载到 SM 的高速寄存器中(边界外用 填充,确保不影响 maxexp 的结果)。
  2. 在寄存器中依次完成 max(归约)、减法、expsum(归约)、除法——所有中间结果从未写回全局内存。
  3. 通过 tl.store 将最终结果一次性写回。

为什么 BLOCK_SIZE 要取 2 的幂? 这是一个硬件友好的选择。GPU 的内存子系统和归约操作在处理 2 的幂大小的数据块时效率最高。triton.next_power_of_2(N) 会将实际列数向上取整,多出的位置由掩码屏蔽。

性能对比。 非融合的手动 Softmax 实现对一行数据产生约 8N 字节的内存流量(5 次读写),而 Triton 融合版本仅产生 2N 字节(一次读、一次写),内存效率提升约 4 倍。实测性能与 PyTorch 官方的高度优化实现和 torch.compile 自动生成的版本相当。


9.3.6 PTX:编译产物检查与性能归因 [选读]

PTX(Parallel Thread Execution)是 NVIDIA 定义的 GPU 中间表示语言,可以理解为 GPU 的"汇编语言"。当我们编写 CUDA 或 Triton 代码后,编译器首先将其翻译为 PTX;GPU 驱动中的 JIT 编译器再将 PTX 进一步编译为特定 GPU 型号可执行的二进制码(SASS)。

为什么要看 PTX? 它是验证编译器优化是否生效的终极手段。当性能未达预期时,检查 PTX 可以回答以下问题:

  • 内存合并是否生效? 如果看到 ld.global.v4.f32(一次加载 4 个 float32),说明编译器成功地将相邻线程的内存请求合并为一次宽事务;如果只看到 ld.global.f32(逐个加载),则可能存在低效的访存模式。
  • 线程粗化是否发生? 分析 Triton GeLU Kernel 生成的 PTX 时,可以观察到单个线程实际处理了 4 到 8 个元素(出现多组 ld.global.v4.f32 + 计算 + st.global.v4.f32 序列),而非"一个线程一个元素"的朴素映射。线程粗化让每个线程承担更多工作,更好地分摊线程调度开销和隐藏内存延迟。
  • 数学运算如何映射? GeLU 近似公式中的 tanh 在 PTX 层面被分解为 ex2.approx.f32(以 2 为底的指数近似)、mul.f32add.f32 等基础指令。通过阅读这些指令序列,可以确认编译器没有引入多余的计算步骤。

以下是 Triton GeLU Kernel 生成的 PTX 片段示例(已简化):

// 寄存器声明
.reg .f32 %f<32>;
.reg .b32 %r<8>;

// 获取线程块 ID 和线程 ID
mov.u32 %r1, %ctaid.x;    // blockIdx.x
mov.u32 %r2, %tid.x;      // threadIdx.x

// 向量化加载:一次读 4 个 float32
ld.global.v4.f32 {%f1, %f2, %f3, %f4}, [%rd5];

// GeLU 计算(以 %f1 为例)
mul.f32  %f5, %f1, %f1;           // x * x
mul.f32  %f6, %f5, %f1;           // x^3
mul.f32  %f7, %f6, 0f3D372713;    // 0.044715 * x^3
add.f32  %f8, %f1, %f7;           // x + 0.044715*x^3
mul.f32  %f9, %f8, 0f3F4C422A;    // sqrt(2/pi) * (...)
// ... tanh 近似 + 最终乘法 ...

// 向量化写回
st.global.v4.f32 [%rd6], {%f25, %f26, %f27, %f28};

从这段 PTX 中可以确认三个关键优化:(1) v4 后缀表明向量化加载/存储生效;(2) 所有中间值存储在 %f 寄存器中,从未溢出到全局内存;(3) 每个线程处理多个元素(线程粗化)。

PTX 的硬件抽象价值。 PTX 还扮演着"向前兼容"的角色。同一份 PTX 代码可以在未来的新 GPU 上被重新编译为针对新硬件优化的 SASS 码,无需修改上层代码。这使得 CUDA 和 Triton 程序具有跨代 GPU 的可移植性。

注意,日常开发中通常不需要阅读 PTX。它更多地作为调试和性能归因的最后手段——当 Nsight 等工具告诉你某个 Kernel 慢,但你看不出为什么时,PTX 是那扇通往真相的窗户。


9.3.7 本节小结

本节从三个层次展示了 GPU 算子开发的实践:

层次工具特点适用场景
底层CUDA C++精细控制,开发成本高极致优化、非标准硬件操作
中层TritonPython 语法,编译器自动优化自定义算子的首选(如 FlashAttention)
高层torch.compile零手动编写,自动融合标准操作的加速

核心经验可以归纳为三条:

  1. 算子融合是性能的分水岭。 同样的数学公式,未融合时(manual_gelu)每个操作独立触发 Kernel,数据在全局内存中反复搬运;融合后(CUDA/Triton 版本)数据只读写全局内存一次,性能差距可达 4~8 倍。
  2. Triton 在易用性和性能之间取得了出色平衡。 它的块级编程模型比 CUDA 的线程级模型更符合深度学习的思维习惯,编译器自动处理了内存合并、共享内存和线程粗化等底层优化,代码量减少约一半,性能却与手写 CUDA 持平。
  3. 理解底层原理始终有价值。 随着 torch.compile 等自动编译器的成熟,大多数标准操作无需手写 Kernel。但当你需要实现 FlashAttention 这样具有非常规内存访问模式的新算法时,理解 CUDA 执行模型、CPU-GPU 异步机制和 PTX 编译产物,是写出正确且高效代码的基础。