Skip to content

6.6 Gated DeltaNet(线性注意力)

标准 Softmax 注意力的核心操作是计算 n×n 的注意力矩阵 softmax(QKT/dk),其时间和空间复杂度均为 O(n2),其中 n 为序列长度。当上下文窗口从 4K 扩展到 128K 乃至更长时,这一二次方开销成为推理延迟和显存消耗的主要瓶颈。Gated DeltaNet 是一种受循环神经网络启发的线性注意力变体,通过将 QKT 矩阵替换为递推状态更新,将单步推理复杂度降至 O(1)(整个序列为 O(n))。该机制源自 Yang 等人 2024 年的论文 Gated Delta Networks: Improving Mamba2 with Delta Rule,随后被 Qwen3-Next 和 Kimi Linear 采纳为其混合架构中的线性注意力层。

本节将从标准注意力的递推视角出发,推导 Gated DeltaNet 的状态更新公式,给出完整的 PyTorch 实现,并分析其与标准注意力在复杂度、KV 缓存和建模能力上的权衡。

线性注意力机制示意图

图 6-20:线性注意力与标准 Softmax 注意力的对比。线性注意力通过递推状态更新替代显式注意力矩阵,将单步推理复杂度从 O(n) 降至 O(1)。

6.6.1 从标准注意力到递推状态机制

标准注意力的瓶颈。 回顾缩放点积注意力的计算:

Attention(Q,K,V)=softmax(QKTdk)V

其中 Q,K,VRn×dQKT 产生一个 n×n 的注意力分数矩阵。在自回归生成的第 t 步,新 token 的 Query 向量 qt 需要与所有历史 Key 向量 k1,,kt 做内积,计算量为 O(td)。整个序列的总计算量为 t=1ntd=O(n2d)

线性注意力的核心思路。 如果我们去掉 Softmax,将注意力写成 QKTV 的形式,利用矩阵乘法的结合律,可以先计算 KTVRd×d(代价为 O(nd2)),再用 Q 左乘(代价为 O(nd2)),总复杂度变为 O(nd2)——对序列长度 n 是线性的。当 dn 时,这比 O(n2d) 快得多。

更进一步,定义一个状态矩阵 StRd×d,通过递推方式逐 token 更新:

St=St1+ktvtT,yt=Stqt

其中 kt,vt,qtRd 分别是第 t 步的 Key、Value、Query 向量,yt 是输出。每一步的计算量仅为 O(d2)(外积更新 + 矩阵-向量乘),与序列位置 t 无关。这就是线性注意力的递推状态机制。

问题:信息只增不减。 上述朴素递推中,St 只能累积信息,无法遗忘过时的内容。当序列很长时,早期 token 的信息会持续干扰后续输出,而固定大小的 d×d 矩阵也无法无限容纳所有历史信息。Gated DeltaNet 正是通过引入"衰减门"和"Delta 规则"来解决这一问题。

6.6.2 Gated DeltaNet 的递推公式

Gated DeltaNet 在朴素线性注意力的基础上引入三个门控信号,完整的递推公式如下:

St=αtSt1+kt[βt(vtSt1Tkt)]Tyt=Stqtot=RMSNorm(yt)SiLU(gt)

各符号含义:

符号维度含义
Std×d状态矩阵(每个头各自维护一个)
αt标量(per head)衰减门(decay gate):控制旧记忆的保留比例,αt(0,1)
βtd 维向量更新门(update gate):控制新信息写入状态的强度
gtd 维向量输出门(output gate):对注意力输出进行逐元素缩放
逐元素乘(Hadamard 积)

逐步解读递推过程:

第一步:衰减旧状态。 αtSt1 将整个状态矩阵按 αt 缩放。当 αt 接近 1 时,模型保留几乎全部历史信息;当 αt 接近 0 时,模型大幅遗忘过去——这类似于 LSTM 中遗忘门的作用。

第二步:计算 Delta(差值)。 δt=βt(vtSt1Tkt) 是 Gated DeltaNet 的核心创新。其中 St1Tkt 是用当前 Key 从旧状态中"检索"出的预测值——如果状态已经能很好地表示 kt 对应的 Value,那么 vtSt1Tkt 趋近于零,状态几乎不更新。反之,差距越大,更新越强。这一"误差驱动"的更新规则与经典的 Delta 学习规则(Widrow-Hoff 规则)同源,本质上是在做联想记忆的在线修正。更新门 βt 进一步调节每个维度的更新幅度。

第三步:写入新信息。 St=αtSt1+ktδtT 将修正量以外积形式写入状态矩阵。kt 充当"地址"向量,δt 充当"内容"向量,外积 ktδtT 是一个秩 1 更新,将新信息沿 kt 方向写入状态。

第四步:读取输出。 yt=Stqt 用 Query 向量从更新后的状态中检索信息。

第五步:门控输出。 ot=RMSNorm(yt)SiLU(gt) 先对输出做归一化稳定数值,再通过 SiLU 门控(而非传统的 Sigmoid 门控)进行缩放。SiLU(xσ(x))允许负值通过(虽然幅度很小),梯度流动性优于标准 Sigmoid。

6.6.3 三个门控的计算方式

三个门控信号均由输入 xt 经线性投影后激活得到:

衰减门 αt(标量,per head):

αt=exp(Asoftplus(Wαxt+bα))

其中 A 是可学习的 per-head 标量参数(以 logA 形式参数化),softplus()=log(1+e()) 保证括号内为正,外层 exp() 保证 αt(0,1)。这一设计来自 Mamba 架构中对状态衰减率的参数化方式,在数值上等价于在对数空间中计算衰减率再取指数。

更新门 βtd 维向量):

βt=σ(Wβxt)

即标准 Sigmoid 门控,逐维度控制更新幅度,βt(0,1)d

输出门 gtd 维向量):

gt=Wgxt

输出门不需要显式激活函数——它在最终输出时与 SiLU 复合使用,SiLU(gt)=gtσ(gt),SiLU 本身已包含 Sigmoid 激活。

6.6.4 PyTorch 实现

线性注意力与 Softmax 注意力的计算对比

图 6-21:线性注意力与标准 Softmax 注意力的计算流程对比。线性注意力通过递推状态更新替代显式注意力矩阵,将计算复杂度从 O(n2) 降至 O(n)

以下是 Gated DeltaNet 的完整 PyTorch 实现。为突出递推机制的核心逻辑,省略了实际部署中常见的卷积混合(short convolution)模块:

python
import torch
import torch.nn as nn
import torch.nn.functional as F


def l2norm(x, dim=-1, eps=1e-6):
    """L2 归一化,用于 Query/Key 的归一化(类似 QKNorm)"""
    return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)


class GatedDeltaNet(nn.Module):
    def __init__(self, d_in, d_out, num_heads, dropout=0.0, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Q / K / V 投影(与标准注意力相同)
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # 三个门控的投影
        self.W_gate = nn.Linear(d_in, d_out, bias=False)   # 输出门
        self.W_beta = nn.Linear(d_in, d_out, bias=False)   # 更新门

        # 衰减门 alpha = exp(-A * softplus(W_alpha(x) + dt_bias))
        self.W_alpha = nn.Linear(d_in, num_heads, bias=False)
        self.dt_bias = nn.Parameter(torch.ones(num_heads))
        A_init = torch.empty(num_heads).uniform_(0, 16)
        self.A_log = nn.Parameter(torch.log(A_init))

        # 输出归一化
        self.norm = nn.RMSNorm(self.head_dim, eps=1e-6)

        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        b, n, _ = x.shape

        # 线性投影
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # 计算三个门控信号
        beta = torch.sigmoid(self.W_beta(x))          # (b, n, d_out)
        alpha_log = -self.A_log.exp().view(1, 1, -1) * F.softplus(
            self.W_alpha(x) + self.dt_bias
        )                                              # (b, n, num_heads)
        alpha = alpha_log.exp()                        # (b, n, num_heads)
        gate = self.W_gate(x)                          # (b, n, d_out)

        # reshape 为多头形式: (b, num_heads, n, head_dim)
        queries = queries.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        beta = beta.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        gate = gate.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)

        # QKNorm:L2 归一化 + 缩放,稳定递推数值
        queries = l2norm(queries, dim=-1) / (self.head_dim ** 0.5)
        keys = l2norm(keys, dim=-1)

        # 初始化状态矩阵 S: (b, num_heads, head_dim, head_dim)
        S = x.new_zeros(b, self.num_heads, self.head_dim, self.head_dim)

        outs = []
        for t in range(n):
            k_t = keys[:, :, t]       # (b, num_heads, head_dim)
            q_t = queries[:, :, t]     # (b, num_heads, head_dim)
            v_t = values[:, :, t]      # (b, num_heads, head_dim)
            b_t = beta[:, :, t]        # (b, num_heads, head_dim)
            a_t = alpha[:, t]          # (b, num_heads)
            a_t = a_t.unsqueeze(-1).unsqueeze(-1)  # (b, num_heads, 1, 1)

            # Step 1: 衰减旧状态
            S = S * a_t

            # Step 2: 从状态中检索预测值,计算 delta
            kv_mem = (S * k_t.unsqueeze(-1)).sum(dim=-2)  # S^T @ k_t
            delta = (v_t - kv_mem) * b_t

            # Step 3: 秩-1 更新写入状态
            S = S + k_t.unsqueeze(-1) * delta.unsqueeze(-2)

            # Step 4: 读取输出
            y_t = (S * q_t.unsqueeze(-1)).sum(dim=-2)     # S @ q_t

            outs.append(y_t)

        # 合并时间步: (b, num_heads, n, head_dim)
        context = torch.stack(outs, dim=2)
        # 转置回 (b, n, num_heads, head_dim) 以便做 RMSNorm
        context = context.transpose(1, 2).contiguous()
        context = context.view(b, n, self.num_heads, self.head_dim)

        # Step 5: RMSNorm + SiLU 输出门控
        context = self.norm(context)
        gate = gate.transpose(1, 2).contiguous()
        gate = gate.view(b, n, self.num_heads, self.head_dim)
        context = context * F.silu(gate)

        # 合并多头并输出投影
        context = context.view(b, n, self.d_out)
        context = self.dropout(context)
        return self.out_proj(context)

代码要点解读:

  1. 状态矩阵 S 的形状为 (b, num_heads, head_dim, head_dim)——这是一个 d×d 的方阵(per head),而非标准注意力中随序列长度增长的 n×n 矩阵。对于典型的 head_dim=128,每个头的状态仅占 128×128×2=32 KB(bf16),远小于长序列的 KV 缓存。
  2. l2norm 对 Q 和 K 做 L2 归一化(类似 QKNorm),确保递推过程中数值不会爆炸或消失。Softmax 注意力天然具有归一化效果(注意力权重和为 1),但线性注意力失去了这一性质,因此需要显式归一化。
  3. kv_mem = (S * k_t.unsqueeze(-1)).sum(dim=-2) 等价于矩阵-向量乘 STkt,利用逐元素乘 + 求和替代显式矩阵乘法。类似地,y_t = (S * q_t.unsqueeze(-1)).sum(dim=-2) 等价于 Sqt
  4. 衰减门 α 的参数化A_log(可学习)→ exp 得到 A → 与 softplus(W_alpha(x) + dt_bias) 相乘 → 取负 → exp 得到 α(0,1)。这种多层变换保证了数值稳定性,并允许模型灵活学习衰减速率。

6.6.5 与标准注意力的复杂度对比

下表对比了标准 Softmax 注意力与 Gated DeltaNet 在推理阶段的关键指标(n 为序列长度,dhead_dimH 为头数,L 为层数):

指标Softmax 注意力Gated DeltaNet
单步推理计算量O(nd)(需与所有历史 Key 做内积)O(d2)(状态矩阵更新 + 查询)
全序列计算量O(n2d)O(nd2)
KV 缓存大小O(LHnd)(随 n 线性增长)O(LHd2)(固定,与 n 无关)
KV 缓存增长方式每生成一个 token,缓存增加 2×L×H×d 个元素无增长,仅维护固定大小的状态矩阵 S
全局上下文建模完整——每个 token 直接关注所有历史 token受限——历史信息被压缩到 d×d 状态矩阵中

关键权衡:nd 时(长上下文场景),O(nd2) 远优于 O(n2d),且 KV 缓存不再随上下文长度膨胀。但 d×d 的状态矩阵构成信息瓶颈——所有历史信息必须被压缩到这一固定容量中,类似 RNN 的隐状态瓶颈。这正是混合架构存在的原因。

6.6.6 KV 缓存对比:数值示例

以一个典型的中大规模模型配置进行量化对比:emb_dim=2048num_heads=16head_dim=128n_layers=48dtype=bf16(2 字节/元素)、batch_size=1

标准注意力的 KV 缓存(随序列长度增长):

KVstd=B×L×n×H×d×2×2=1×48×n×16×128×2×2
序列长度 nKV 缓存大小
4,0961.50 GB
32,76812.00 GB
131,07248.00 GB

Gated DeltaNet 的状态内存(固定):

StateGDN=B×L×H×d×d×2=1×48×16×128×128×225.17 MB

这一数值与序列长度完全无关。即使在 128K 上下文长度下,标准注意力需要 48 GB KV 缓存,而 Gated DeltaNet 仅需约 25 MB 的固定状态——相差近 2000 倍。

当然,这里的比较仅针对线性注意力层。在混合架构(如 3:1 配置)中,全局注意力层仍然需要 KV 缓存,但总量减少为纯全局注意力架构的 1/4

6.6.7 混合架构:3:1 策略

状态空间模型与递推架构

图 6-22:状态空间模型(SSM)的递推架构。DeltaNet 与 Mamba 等模型共享递推状态更新的核心思想,通过固定大小的状态矩阵压缩全部历史信息。

纯线性注意力虽然效率极高,但其状态矩阵的有限容量使其在需要精确长距离检索的任务上(如从长文档中提取特定事实)表现不如全局注意力。因此 Qwen3-Next 和 Kimi Linear 均采用混合策略:每 4 个 Transformer 块中,3 个使用 Gated DeltaNet 线性注意力,1 个使用全局注意力(Qwen3-Next 使用带输出门控的标准多头注意力,Kimi Linear 使用 MLA),比例为 3:1。

这一设计的考量包括:

  1. 效率:全模型 75% 的注意力层为线性复杂度,整体计算量和 KV 缓存大幅减少。
  2. 上下文能力:每隔 3 层插入一个全局注意力层,为模型提供直接访问完整历史的通道,弥补线性注意力的信息压缩损失。
  3. 训练稳定性:Qwen3-Next 在全局注意力层中使用 Sigmoid 输出门控,消除了 Attention Sink(注意力权重集中在首 token)和 Massive Activation(激活值异常放大)等数值稳定性问题。

Kimi Linear 的改进:通道级衰减门。 Qwen3-Next 的衰减门 αt 是 per-head 的标量,即同一头内所有维度共享相同的衰减率。Kimi Linear 将其替换为 per-channel 的向量(即 αtRd),允许不同特征维度以不同速率遗忘。根据 Kimi Linear 论文的实验,这一改进在长上下文推理任务上带来了可测量的性能提升。

6.6.8 推理流程对比

将标准注意力和 Gated DeltaNet 的自回归推理流程并排对比,可以清晰看出两者在缓存机制上的根本差异:

标准注意力的推理(以第 t 步为例):

  1. 计算新 token 的 qt,kt,vt
  2. kt,vt 追加到 KV 缓存:cacheK[k1,,kt1,kt]
  3. 计算 qt 与缓存中所有 Key 的点积:O(td)
  4. Softmax + 加权求和得到输出。

Gated DeltaNet 的推理(以第 t 步为例):

  1. 计算新 token 的 qt,kt,vt 以及三个门控 αt,βt,gt
  2. 更新状态矩阵:St=αtSt1+ktδtT,代价 O(d2)
  3. 从状态中查询输出:yt=Stqt,代价 O(d2)
  4. 门控输出:ot=RMSNorm(yt)SiLU(gt)

整个过程无需存储任何历史 K/V 向量,仅需维护一个固定大小的状态矩阵 S。这意味着生成第 1000 个 token 和第 100,000 个 token 的单步计算量和内存消耗完全相同。

DeltaNet 递推状态更新机制

图 6-23:DeltaNet 的递推状态更新。状态矩阵通过门控遗忘和 Delta 规则增量更新,实现对序列信息的高效压缩存储。

本节小结

本节介绍了 Gated DeltaNet 线性注意力的原理与实现:

  • 核心机制:用递推状态矩阵 StRd×d 替代 n×n 注意力矩阵。每步通过"衰减-检索-修正-写入-查询"五步递推更新状态,单步计算量为 O(d2),全序列为 O(nd2),对序列长度线性。
  • 三个门控:衰减门 α(控制遗忘速率)、更新门 β(控制写入强度)、输出门 g(SiLU 激活,控制输出缩放)。Delta 规则(误差驱动更新)使状态更新具有选择性——已经记住的信息不会被重复写入。
  • KV 缓存优势:Gated DeltaNet 的"缓存"是固定大小的状态矩阵(如 head_dim=128 时每头仅 32 KB),不随序列长度增长。在 128K 上下文、48 层模型的配置下,标准注意力需要约 48 GB KV 缓存,而 Gated DeltaNet 仅需约 25 MB。
  • 建模能力代价:固定大小的状态矩阵构成信息瓶颈,无法像全局注意力那样直接关注任意历史位置。因此实际部署中采用 3:1 混合策略(3 层线性注意力 + 1 层全局注意力),在效率和上下文建模能力之间取得平衡。
  • 工业实践:Qwen3-Next 和 Kimi Linear 均采用此架构。Kimi Linear 进一步将 per-head 的标量衰减门改为 per-channel 的向量衰减门,提升了长上下文推理性能。