当 Transformer 遇上状态空间模型:结构化状态空间对偶性揭秘

近年来,深度学习领域取得的巨大成功离不开 Transformer 架构的贡献,尤其是在语言建模方面。然而,随着模型规模的不断扩大,Transformer 的二次时间复杂度成为了其进一步发展的瓶颈。与此同时,状态空间模型(SSM),例如 Mamba,展现出与 Transformer 相媲美甚至更优的性能,并且在中小规模模型上具有线性时间复杂度优势。

本文将深入探讨 Transformer 与 SSM 之间的联系,并提出一个全新的理论框架——结构化状态空间对偶性(SSD)。该框架揭示了 SSM 与各种注意力变体之间的密切关系,并通过对结构化半可分矩阵的不同分解方式建立了联系。基于 SSD 框架,我们设计了一种全新的架构——Mamba-2,其核心层是对 Mamba 选择性 SSM 的改进,速度提升了 2-8 倍,同时在语言建模方面仍然可以与 Transformer 竞争。

Transformer 与 SSM 的前世今生

Transformer:深度学习的明星架构

Transformer,特别是仅解码器模型(例如 GPT、Llama),以其强大的序列建模能力成为了现代深度学习成功的关键驱动力之一。然而,其核心注意力层的二次时间复杂度问题一直是研究者们努力攻克的难题。

状态空间模型:线性复杂度的挑战者

结构化状态空间模型(SSM)作为一种新兴的序列模型,在长程任务(例如 S4)中表现出色,并且最近在中小型语言建模任务中与 Transformer 达到或超过了 Transformer 的性能(例如 Mamba)。SSM 具有线性时间复杂度,在训练和推理过程中效率更高。

结构化状态空间对偶性:架起沟通的桥梁

为了更好地理解和改进 SSM,我们提出了结构化状态空间对偶性(SSD)框架,该框架通过结构化矩阵的抽象概念将 SSM 与注意力变体联系起来。

结构化矩阵:SSD 框架的核心

结构化矩阵的定义和性质

结构化矩阵是指具有以下两个特性的矩阵:

  1. 可以通过压缩表示以亚二次(理想情况下是线性)参数表示。
  2. 具有快速算法(最重要的是矩阵乘法),可以直接对其压缩表示进行操作。

半可分矩阵:SSM 的矩阵表示

半可分矩阵是一种重要的结构化矩阵,其定义为:矩阵下三角部分中的每个子矩阵的秩最多为 N,其中 N 称为半可分矩阵的阶数或秩。

本文证明了 SSM 与半可分矩阵之间的等价性,并通过半可分矩阵的顺序半可分(SSS)表示形式建立了联系。

定理 3.5:状态大小为 N 的状态空间模型变换 y = SSM(A, B, C)(x) 与通过 N 阶半可分矩阵(以 SSS 表示形式)进行矩阵乘法 y = SSS(A, B, C) · x 相同。

基于结构化矩阵算法的 SSM 计算

通过将 SSM 视为半可分矩阵,我们可以利用结构化矩阵乘法算法来高效地计算 SSM。

线性(递归)模式:利用 SSM 的递归形式,可以以线性时间复杂度计算 SSM。

二次(朴素)模式:直接计算 SSM 的矩阵表示,时间复杂度为二次,但对于短序列长度,由于计算模式的硬件友好性,这种方法可能比线性算法更有效。

结构化掩码注意力:线性注意力的推广

注意力框架

注意力机制的核心是为序列中每对位置分配分数,从而使每个元素能够“关注”其他元素。最常见的注意力变体是 softmax 自注意力,其定义为:

Y = softmax(QKᵀ) · V

线性注意力

线性注意力通过将 softmax 折叠到核特征映射中,并利用矩阵乘法的结合律来重写注意力计算,从而避免了 softmax 的计算。

命题 4.1:自回归核注意力(即具有因果掩码的掩码核注意力)可以通过每次步骤花费恒定时间的递归以 O(T) 时间计算。

结构化掩码注意力

结构化掩码注意力(SMA)将线性注意力推广到使用任何结构化掩码 L 的情况,只要 L 具有亚二次矩阵乘法即可。

定义 4.2:结构化掩码注意力(SMA)(或简称结构化注意力)定义为对查询/键/值 Q, K, V 以及任何结构化矩阵 L(即具有亚二次矩阵乘法)的函数,通过四路张量收缩:

Y = contract(TN, SN, SP, TS → TP)(Q, K, V, L)

状态空间对偶性:SSM 与 SMA 的交汇

标量-单位结构化 SSM

当 SSM 中的 A 矩阵是标量时,其朴素二次计算可以看作是核注意力的一种实例。

1-半可分结构化掩码注意力

当 SMA 中的掩码 L 是 1-半可分矩阵时,其线性计算形式是状态空间模型的一种特例。

推论 5.1:1-SS SMA(具有 1-半可分结构化矩阵 L 的掩码注意力)(15)是对角 SSM(8)的一种特例,其中对角矩阵是单位矩阵的标量倍数。

定理 5.2:对于任何作为有界阶自回归过程的结构化掩码注意力(定义 4.2)实例,结构化掩码 L 必须是半可分矩阵。

结构化状态空间对偶性

SSD 框架揭示了 SSM 与 SMA 之间的对偶关系,其中线性 SSM 算法和二次核注意力算法是彼此的对偶形式。

SSD 模型的硬件高效算法

块分解

为了高效地计算 SSD 模型,我们采用了一种块分解方法。将矩阵 M 分解成大小为 Q × Q 的子矩阵的 T/Q × T/Q 网格,其中 Q 是块大小。对角块可以使用二次 SMA 模式高效计算,而离对角块可以利用半可分矩阵的秩结构分解为更小的递归。

计算成本

SSD 算法的计算成本与线性 SSM 相同,但其硬件友好性与注意力机制相当,主要使用矩阵乘法。

定理 6.1:考虑状态扩展因子为 N 且头部维度为 P = N 的 SSD 模型。存在一种算法,可以在任何输入 X ∈ R(T,P) 上计算模型,该算法仅需要 O(TN²) 训练 FLOP、O(TN) 推理 FLOP、O(N²) 推理内存,并且其工作量主要由矩阵乘法决定。

Mamba-2 架构

块设计

Mamba-2 架构对 Mamba-1 块进行了一些修改,这些修改部分是受注意力机制的启发,也是为了提高 Mamba-2 的可扩展性。

并行参数投影:Mamba-2 在块的开头使用单个投影并行生成 A, X, B, C,这与标准注意力架构类似,其中 X, B, C 对应于并行创建的 Q, K, V 投影。

额外的归一化:为了提高稳定性,在最终输出投影之前添加了一个额外的归一化层(例如 LayerNorm、GroupNorm 或 RMSNorm)。

多头模式

类似于多头注意力,Mamba-2 也采用了多头模式,其中状态大小 N 和头部维度 P 分别类似于注意力的 QK 头部维度和 V 头部维度。

多输入 SSM (MIS) / 多值注意力 (MVA) 模式:Mamba-2 使用 MVA 模式,其中 BC 矩阵(对应于注意力中的 KQ)在输入 X 的所有通道(对应于注意力中的 V)之间共享。

其他 SSD 扩展

SSD 还可以结合线性注意力文献中的其他思想,例如各种形式的核近似。

核注意力近似 softmax 注意力:Mamba-2 中包含一个灵活的核特征映射,并将其应用于 BC 分支(对应于注意力中的 KV 分支)。

合并归一化(分母)项:可以通过在 X 中添加一个额外的列 1 来找到分母项,从而得到形状为 (T, P + 1) 的张量。

实验验证

关联回忆

在多查询关联回忆(MQAR)任务中,Mamba-2 表现出色,并且随着状态大小的增加,性能持续提高。

语言建模

在标准语言建模任务中,Mamba-2 在困惑度和零样本评估方面与其他架构相比具有竞争力。

速度基准

SSD 算法比 Mamba 的扫描实现快 2-8 倍,并且在中等序列长度下与优化的注意力机制相当。

相关工作和讨论

状态空间模型

SSD 可以描述为具有 SISO 维度和标量-单位结构的选择性 SSM。

结构化矩阵

SSD 框架将 SSM 视为具有特定结构的矩阵混合器——半可分矩阵。

(线性)注意力

SSD 与标准(因果)注意力的主要区别在于:

  1. SSD 不使用 softmax 激活函数。
  2. SSD 将 logits 矩阵乘以一个输入相关的 1-半可分掩码。

相关模型

最近出现了一些与 Mamba 和 Mamba-2 非常相似的序列模型,例如 RetNet、TransNormerLLM、GateLoop、Gated Linear Attention (GLA)、HGRN、Griffin、xLSTM 和 RWKV(-4)。

结论

SSD 框架为理解和改进 SSM 提供了一个新的视角,并为设计更高效、更强大的序列模型开辟了新的方向。

参考文献

  • Dao, T., Gu, A., et al. (2019). M2: A high-performance monarch matrix multiplication library. In Proceedings of the 2019 IEEE/ACM International Symposium on Code Generation and Optimization (CGO) (pp. 174–185).
  • Gu, A., Goel, K., & Ré, C. (2022). Efficiently Modeling Long Sequences with Structured State Spaces. In International Conference on Learning Representations.
  • Gu, A., & Dao, T. (2023). Mamba: Linear-Complexity Attention with Selective State Spaces. arXiv preprint arXiv:2307.00855.
  • Katharopoulos, A., Vyas, A., Pappas, N., & Fleuret, F. (2020). Transformers are rnns: Fast autoregressive transformers with linear attention. In International Conference on Machine Learning (pp. 5156–5165). PMLR.
  • Pernet, C., & Storjohann, A. (2018). Time and space efficient generators for quasiseparable matrices. Journal of Symbolic Computation, 85, 224–246.
  • Sun, Y., Dehghani, M., et al. (2023). Retentive Network: A Successor to Transformer for Large Language Models. arXiv preprint arXiv:2307.08621.

Leave a Comment