Skip to content

3.2 核心组件逐一拆解

上一节从宏观视角勾勒了 Transformer 的整体架构和数据流。然而,仅仅知道"注意力子层后面跟着前馈网络"是不够的——魔鬼藏在细节中。从原始 Transformer 到 LLaMA、Mistral、Qwen 等现代大语言模型,架构的核心骨架几乎没有变化,真正改变的是三个看似不起眼却至关重要的组件:层归一化前馈网络残差连接。这三者的设计选择直接决定了模型能否在数千层深度和万亿参数规模下稳定训练。

本节将逐一拆解这三个组件:先讲清楚它们"做什么"、"为什么要这么做",再给出数学公式和 PyTorch 实现,最后揭示它们在现代 LLM 中的演进方向。

3.2.1 层归一化

深层网络训练面临的首要挑战是内部协变量偏移(Internal Covariate Shift):随着梯度更新,每一层的输入分布不断漂移,导致后续层需要反复适应新的输入统计量。归一化技术通过在每一层强制将激活值映射到稳定的分布,从根本上缓解了这一问题,使得优化过程更加平滑,允许使用更高的学习率。

LayerNorm

层归一化(Layer Normalization)对单个样本的隐藏维度计算均值和方差,然后进行标准化。设输入向量 xRd,LayerNorm 的计算过程如下:

μ=1di=1dxi,σ2=1di=1d(xiμ)2LayerNorm(x)=γxμσ2+ε+β

其中 γ,βRd 是可学习的缩放和偏移参数,ε(通常取 105)防止除零。减均值操作称为重新居中(re-centering),除以标准差称为重新缩放(re-scaling),两者共同保证了输出的零均值和单位方差。

与 BatchNorm 的区别。 BatchNorm 沿 batch 维度计算统计量,隐含假设同一特征在不同样本间服从相似分布,这在 CNN 的空间特征中是合理的。但在序列建模中,不同样本的序列长度和内容差异巨大,batch 维度的统计量噪声过高。LayerNorm 沿隐藏维度计算,每个样本独立归一化,不依赖 batch 内的其他样本,因此天然适合 Transformer 架构。

python
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta

RMSNorm

Zhang 和 Sennrich(2019)在论文"Root Mean Square Layer Normalization"中提出了一个关键假设:LayerNorm 的成功主要归功于重新缩放的不变性,而非均值居中操作。基于这一洞察,RMSNorm 直接移除了均值减法步骤和偏置参数 β,仅使用均方根(Root Mean Square)进行归一化:

RMS(x)=1di=1dxi2RMSNorm(x)=γxRMS(x)+ε

移除均值居中的直觉在于:在深层网络中,经过多层变换后的激活值均值本就趋近于零,显式减去均值带来的收益有限,反而增加了一次全局归约操作。从计算角度看,RMSNorm 省去了均值计算和方差计算中的减法步骤,只需要一个可学习参数 γ(而非 γβ 两个)。虽然 FLOPs 的减少量微不足道,但在以内存带宽为瓶颈的现代 GPU 上,更少的内存读写意味着更高的实际吞吐量。实验表明,这种简化可以带来 7%–64% 的速度提升,且对模型性能几乎没有影响。

RMSNorm 已成为 LLaMA、PaLM、Qwen 等现代大语言模型的标准归一化方法。

python
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # rsqrt = 1 / sqrt(x),避免先 sqrt 再除法的两步操作
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 先转 float32 计算以避免半精度下溢,再转回原始 dtype
        return self.weight * self._norm(x.float()).type_as(x)

代码解读。 _norm 方法使用 torch.rsqrt(倒数平方根)一步完成归一化,避免了先 sqrt 再除法的两步操作。x.float() 将输入临时提升到 float32 精度进行归一化计算,再通过 type_as(x) 转回原始精度(如 bfloat16)。这是因为半精度浮点数的最小正数约为 6×108(float16)或精度有效位较少(bfloat16),直接在半精度下计算平方均值容易发生下溢或精度损失。

Pre-Norm vs Post-Norm

归一化"用什么"只是问题的一半,"放在哪里"同样关键。

Post-Norm(原始 Transformer)。 原始 Transformer 将归一化放在残差连接之后:

x=LayerNorm(x+SubLayer(x))

这意味着 LayerNorm 直接位于残差主路径上。从梯度流的角度看,反向传播时梯度必须穿过 LayerNorm 层才能回传到更早的层,而 LayerNorm 的 Jacobian 矩阵会对梯度进行非平凡的缩放和旋转,这在深层网络中逐层累积,可能导致梯度爆炸或消失。因此 Post-Norm 的训练往往需要精细的学习率预热(Warmup)策略,否则容易发散。

Pre-Norm(现代标准)。 Pre-Norm 将归一化移到子层的输入端:

x=x+SubLayer(LayerNorm(x))

这一看似微小的调整带来了本质性的改善。注意此时 LayerNorm 位于子层的分支路径上,而残差主路径 xx 是一条"干净"的恒等连接——梯度可以无损地从网络顶层直接回传到底层,不受归一化层的干扰。这使得训练稳定性大幅提升,模型可以容忍更高的学习率,不再严重依赖 Warmup。

Pre-Norm 与 Post-Norm 结构对比

图 3-3:Pre-Norm 与 Post-Norm 结构对比。Post-Norm 中归一化位于残差加法之后的主路径上;Pre-Norm 中归一化位于子层输入端的分支路径上,残差主路径保持干净的恒等连接。几乎所有现代 LLM 均采用 Pre-Norm 结构。

几乎所有现代大语言模型都采用 Pre-Norm 结构。 在 LLaMA 系列中,具体配置为 Pre-RMSNorm——在每个注意力子层和前馈网络子层之前分别施加 RMSNorm。值得一提的是,Grok、Gemma 2 等最新模型甚至引入了 Double Norm:在 Pre-Norm 的基础上,于子层输出端再增加一次归一化,提供额外的数值稳定性保障。

3.2.2 前馈网络

Transformer 中的每个注意力子层后面都紧跟一个位置逐点前馈网络(Position-wise Feed-Forward Network, FFN)。注意力子层负责在序列维度上聚合上下文信息,而 FFN 则在每个位置独立地进行特征变换——可以将其理解为对每个 token 的表示施加一次非线性特征提取。

标准 FFN

原始 Transformer 中的 FFN 由两个线性变换和一个 ReLU 激活函数组成:

FFN(x)=ReLU(xW1+b1)W2+b2

其中 W1Rd×dff 是"升维"矩阵,W2Rdff×d 是"降维"矩阵,中间维度 dff 通常取 4d(如 d=768dff=3072)。FFN 本质上就是一个两层 MLP:先升维到高维空间以获得更强的特征表达能力,再降维回原始维度。

激活函数的演进:ReLU -> GELU -> Swish

ReLU ReLU(x)=max(0,x) 是深度学习时代最经典的激活函数,计算极其简单高效。但它存在Dying ReLU 问题:如果一个神经元的输入在训练过程中持续为负,其梯度恒为零,对应的权重将永远无法更新——这个神经元"死掉"了。

GELU(Gaussian Error Linear Unit)是 ReLU 的平滑近似,被 GPT 和 BERT 系列广泛采用:

GELU(x)=xΦ(x)=x12(1+erf(x2))

其中 Φ(x) 是标准正态分布的累积分布函数。与 ReLU 的硬截断不同,GELU 在零点附近是平滑的,允许负值输入产生非零但极小的输出。这种平滑性一方面避免了 Dying ReLU 问题,另一方面可以被解释为一种随机正则化——以概率 Φ(x) 保留输入,以概率 1Φ(x) 将其置零。

Swish 函数是另一种平滑激活:Swish(x)=xσ(x),其中 σ 为 sigmoid 函数。当 β=1 时,Swish 等价于 SiLU(Sigmoid Linear Unit),也是 PyTorch 中 nn.SiLU() 的默认实现。Swish 与 GELU 的形状非常接近,但计算上仅依赖 sigmoid,无需误差函数。

SwiGLU:门控机制的引入

标准 FFN 中的激活函数是"静态"的——对每个维度施加相同的非线性变换,不考虑输入的内容。Dauphin 等人(2017)提出的门控线性单元(Gated Linear Unit, GLU)引入了一种全新的思路:让网络根据输入内容动态决定每个维度的信息通过量

GLU 的一般形式为:

GLU(x)=(xWup)σ(xWgate)

其中 表示逐元素乘法,σ 是某种激活函数。xWgate 经过激活函数后生成一个"门控"向量,其值控制着 xWup 中每个维度的信息流——接近零时"关闭"该维度,接近正值时"放行"。这一额外的门控路径使得 FFN 能够进行内容自适应的特征选择,表达能力远超静态激活函数。

SwiGLU 是 GLU 家族中性能最优的变体(Shazeer, 2020),它使用 Swish 作为门控激活函数。完整的 SwiGLU FFN 公式为:

FFNSwiGLU(x)=(Swish(xWgate)xWup)Wdown

其中 Wgate,WupRd×dffWdownRdff×d。注意,与标准 FFN 的两个权重矩阵不同,SwiGLU 引入了三个权重矩阵——WgateWup 分别生成门控信号和信息主干,Wdown 将结果降维回原始维度。

图 3-4:SwiGLU FFN 的结构示意图。输入 x 被同时送入两个并行的线性层:Wgate 路径经 Swish 激活后生成门控信号,Wup 路径生成信息主干。两者逐元素相乘后,经 Wdown 降维输出。相比标准 FFN 的"升维->激活->降维",SwiGLU 增加了一条门控路径,实现了内容自适应的特征选择。

参数量的平衡。 三个矩阵意味着比标准 FFN 多出 50% 的参数量。为了在总参数量不变的前提下使用 SwiGLU,通常将中间维度 dff4d 缩减为 8d3,使总参数量 3×d×8d3=8d2 与标准 FFN 的 2×d×4d=8d2 保持一致。在工程实现中,dff 还会进一步对齐到 256 的整数倍,以充分利用 GPU 的内存对齐和并行计算优势。

以下代码同时包含从 T5 开始普遍采用的去偏置设计(bias=False):

python
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    def __init__(self, dim: int, multiple_of: int = 256):
        super().__init__()
        # 中间维度: 8d/3,对齐到 multiple_of 的整数倍
        mid_dim = int(8 * dim / 3)
        mid_dim = multiple_of * ((mid_dim + multiple_of - 1) // multiple_of)

        self.w_gate = nn.Linear(dim, mid_dim, bias=False)
        self.w_up = nn.Linear(dim, mid_dim, bias=False)
        self.w_down = nn.Linear(mid_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Swish(x @ W_gate) ⊙ (x @ W_up),再降维
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

代码解读。 F.silu 即 Swish 函数(xσ(x)β=1 时的 SiLU)。w_gatew_up 并行地对输入进行线性变换,前者经 Swish 激活后作为门控信号,与后者逐元素相乘,最终由 w_down 降维输出。整个前向过程只有三次矩阵乘法和一次逐元素乘法,计算模式对 GPU 非常友好。

自 LLaMA 以来,SwiGLU 已取代 GELU 成为大语言模型 FFN 层的事实标准。GeGLU(将 Swish 替换为 GELU)是另一个有竞争力的变体,被 T5-v1.1、Phi-3 等模型采用。论文作者坦诚地表示"我们无法为这一改进提供单一的解释"——但实验数据一致表明,门控机制在各种基准上均带来了可观的困惑度下降。

3.2.3 残差连接

残差连接与 Pre-Activation 设计

图 3-5:残差连接的两种设计及其训练效果对比。(a) 原始残差块(Post-Activation):weight → BN → ReLU 顺序,归一化在残差加法之前;(b) Pre-Activation 残差块:BN → ReLU → weight 顺序,归一化放在变换之前,残差主路径更干净。右图展示了在 1001 层 ResNet 上的训练曲线——Pre-Activation(蓝色)显著优于原始设计(黄色),这一原理直接延伸为 Transformer 中 Pre-Norm 的设计动机。

残差连接可能是深度学习中最简单却影响最深远的设计。它的形式只有一行公式:

x=x+F(x)

其中 F 表示子层的变换(注意力或前馈网络)。输入 x 通过一条"捷径"(shortcut)直接跳过子层与输出相加——因此也称为跳连(skip connection)。

退化问题与残差学习

1.2 节曾从 CNN 的视角介绍了残差连接的起源。这里我们从 Transformer 的语境出发,给出一个更深入的理论分析。

深层网络面临的核心困境是退化问题(Degradation Problem):随着网络层数增加,训练误差反而上升——不是过拟合,而是优化本身失败了。这个现象的一个思想实验可以揭示问题的本质:假设一个 56 层网络的最优解恰好等价于某个 20 层网络——多出来的 36 层只需要学会恒等映射即可。但实验表明,SGD 很难将一个带有非线性激活和归一化的层训练成恒等映射。

残差连接将问题巧妙地转换了:子层不再需要学习完整的目标映射 H(x),而只需要学习残差 F(x)=H(x)x。如果某一层应该是恒等映射,那么 F(x)=0——将一组权重推向零远比学习恒等映射容易得多。

残差块的结构可参见图 1-7(§1.2)。在 Transformer 的 Pre-Norm 配置下,子层路径为 SubLayer(RMSNorm(x)),捷径路径为恒等映射,两者相加后传入下一层。

函数族论证:更大模型必覆盖更小模型

残差连接赋予深层网络一个优雅的理论性质,可以用函数族嵌套的视角来理解。

FL 表示 L 层残差网络能够表达的所有函数的集合。对于一个 L 层残差网络,每一层的映射为 xl+1=xl+Fl(xl)。考虑一个更深的 (L+1) 层网络,它可以表达 FL+1 中的所有函数。关键观察是:

如果第 (L+1) 层令 FL+1(x)=0(即所有参数为零),则该层退化为恒等映射,整个 (L+1) 层网络等价于原来的 L 层网络。

这意味着 FLFL+1——更深的残差网络的函数族严格包含更浅网络的函数族。用集合论的语言表述:

F1F2FLFL+1

这一嵌套性质保证了:增加网络深度只会扩大(或至少不缩小)模型的表达能力。更大的模型永远能够覆盖更小模型所能表达的全部函数,因为它可以简单地将多余的层"关闭"。在没有残差连接的网络中,这一性质不成立——一个更深但没有跳连的网络,其函数族可能与浅层网络不具有包含关系,因为学习恒等映射本身就是困难的。

这也是为什么我们可以放心地将 Transformer 从 12 层扩展到 96 层乃至数百层:残差连接在理论上保证了深层网络至少不会比浅层网络差——即使某些层学到的变换接近于零,模型的整体表达能力也不会退化。

梯度高速公路

残差连接在优化层面的贡献同样深远。考虑一个 L 层残差网络的前向过程:

xL=x0+l=0L1Fl(xl)

对损失 L 关于第 l 层输入 xl 求梯度:

Lxl=LxLxLxl=LxL(I+xlk=lL1Fk(xk))

关键在于那个 I——无论中间的子层变换 Fk 多么复杂,梯度总有一条"高速公路"可以直接从损失函数传回任意一层。即使某些 Fkxl 趋近于零(梯度消失)或非常大(梯度爆炸),恒等项 I 始终提供稳定的梯度信号。这正是残差网络能够训练到数百层深度的根本原因。

以下代码展示了 Transformer 块中残差连接与 Pre-RMSNorm 的标准组合:

python
class TransformerBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int, multiple_of: int = 256):
        super().__init__()
        self.attn_norm = RMSNorm(dim)
        self.attn = MultiHeadAttention(dim, num_heads)  # 参见 §2.3
        self.ffn_norm = RMSNorm(dim)
        self.ffn = SwiGLUFFN(dim, multiple_of)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 残差 + Pre-RMSNorm + 注意力
        x = x + self.attn(self.attn_norm(x))
        # 残差 + Pre-RMSNorm + SwiGLU FFN
        x = x + self.ffn(self.ffn_norm(x))
        return x

代码解读。 这个 TransformerBlock 浓缩了本节介绍的三个核心组件:RMSNorm 实现层归一化,SwiGLUFFN 实现门控前馈网络,而 x = x + ... 就是残差连接。每个子层遵循相同的模式——先归一化、再变换、最后加回残差——简洁而统一。整个现代 Transformer 就是将这样的块堆叠数十至数百层。

本节小结

本节拆解了 Transformer 架构中三个看似简单却至关重要的核心组件:

  • 层归一化 通过稳定每一层的输入分布来保障训练稳定性。从 LayerNorm 到 RMSNorm,去掉均值居中在几乎不损失性能的前提下显著提升了计算效率。从 Post-Norm 到 Pre-Norm,归一化位置的调整保证了残差主路径的梯度畅通无阻。
  • 前馈网络 在每个位置独立地进行特征变换。从 ReLU 到 GELU 再到 SwiGLU,激活函数的演进方向是引入门控机制实现内容自适应的特征选择,使 FFN 从静态非线性变换升级为动态信息过滤器。
  • 残差连接x=x+F(x) 这一极简形式,同时解决了函数族嵌套(更深的模型表达能力不退化)和梯度传播(提供从损失到任意层的直达通道)两大根本性问题。

这三个组件的协同构成了现代 LLM 的"基本粒子"。理解它们各自的设计动机和演进逻辑,是深入理解下一节将介绍的旋转位置编码(RoPE)以及后续章节中各类大语言模型架构变体的必要前提。