缓存与效果的极限拉扯:从 MHA 、 MQA 、 GQA 到 MLA

引言

最近,幻方发布的 DeepSeek-V2 引起了广泛关注。其 1 块钱 100 万 token 的价格令人惊叹,而背后的关键技术之一——MLA(Multi-head Latent Attention) 更是备受瞩目。本文将带大家梳理从 MHA 、 MQA 、 GQA 到 MLA 的演变历程,并深入介绍 MLA 的设计思路。

MHA:多头注意力

MHA(Multi-Head Attention) 是 《Attention is all you need 》提出的注意力机制的基础。它通过多个独立的单头注意力拼接而成,广泛应用于当前的主流 LLM(大语言模型) 。

MHA 的设计使得每个注意力头 (Head) 都有独立的键 (Key) 、值 (Value) 和查询 (Query) 向量,这些向量通过线性变换得到。 MHA 的计算量和存储开销较大,特别是在长上下文 (Context) 情况下,KV Cache(键值缓存) 会占用大量显存。

瓶颈:为何降低 KV Cache 大小如此重要?

LLM 的推理主要在 GPU 上进行,而 GPU 显存有限。一部分显存用于存放模型参数和激活值,另一部分用于存放 KV Cache 。随着上下文长度增加,KV Cache 的大小会逐渐占据主导地位,可能超出单张卡甚至单台机器的显存容量。

减少 KV Cache 的目的是在更少的设备上推理更长的上下文,或在相同上下文长度下提高批处理大小,从而实现更快的推理速度或更大的吞吐量,最终降低推理成本。

MQA:多查询注意力

MQA(Multi-Query Attention) 是减少 KV Cache 的一次尝试,首次提出于 《Fast Transformer Decoding: One Write-Head is All You Need 》。 MQA 的思路是让所有注意力头共享同一个 Key 和 Value,从而将 KV Cache 减少到原来的 1/h(h 是头的数量) 。

尽管 MQA 在显存占用上有显著优势,但其效果在某些任务上可能有所下降。为了弥补这一损失,研究人员提出了 GQA 。

GQA:分组查询注意力

GQA(Grouped-Query Attention) 是 MHA 与 MQA 之间的过渡版本,出自 《GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints 》。 GQA 将所有注意力头分为 g 个组,每组共享同一对 Key 、 Value 。这样既可以减少 KV Cache,又能在一定程度上保留多样性。

GQA 提供了从 MHA 到 MQA 的自然过渡,当 g=h 时是 MHA,g=1 时是 MQA,1 < g < h 时则是 GQA 。 GQA 在 KV Cache 压缩率和效果之间提供了平衡。

MLA:多头潜在注意力

Part 1:增强模型能力

MLA(Multi-head Latent Attention) 对 GQA 进行了改进,采用低秩投影的方式替代 GQA 的分割、重复。 MLA 通过不同的投影矩阵增强模型能力,并在推理阶段通过恒等变换技巧减少 KV Cache 。

Part 2:兼容 RoPE

MLA 的一个难题是如何兼容 RoPE(旋转位置编码) 。 RoPE 是一个位置相关的矩阵,MLA 通过引入一种混合方法,每个注意力头的 Query 和 Key 新增部分维度用于添加 RoPE,从而保持 KV Cache 的减少效果。

Part 3:减少训练参数量

MLA 的最终版本将 Query 输入改为低秩投影形式,减少训练期间的参数量和相应的显存占用。推理阶段,通过恒等变换技巧减少 KV Cache,同时保持高效的计算。

小结

本文概述了多头注意力的演变历程,特别是从 MHA 到 MQA 、 GQA,最终到 MLA 的变化理念。 MLA 通过低秩投影和恒等变换技巧实现了 KV Cache 的进一步压缩,同时兼容 RoPE,称得上是一种非常实用的注意力变体。

转载本文请包括本文地址:https://kexue.fm/archives/10091

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。如果您觉得本文不错,欢迎分享或打赏本文。打赏并非为了获得收益,而是希望知道科学空间获得了多少读者的真心关注。再次表示欢迎和感谢!


苏剑林. (May. 13, 2024). 《缓存与效果的极限拉扯:从 MHA 、 MQA 、 GQA 到 MLA 》 [Blog post]. Retrieved from 科学空间


MLA(Multi-head Latent Attention) 是一种对 GQA(Generalized Query Attention) 进行改进的注意力机制。它采用低秩投影的方式替代了 GQA 中的分割和重复操作,同时通过恒等变换技巧减少了 KV Cache 的使用。 MLA 的核心思想是通过不同的投影矩阵增强模型的能力,并在推理阶段通过恒等变换技巧减少 KV Cache 的存储和计算开销。

MLA 的改进主要解决了推理过程中的 KV Cache 问题,从而实现在更少的设备上推理更长的上下文,或者在相同的上下文长度下增大批处理大小,以实现更快的推理速度或更大的吞吐量,从而降低推理成本。

与经典的 MHA(Multi-head Attention) 和 GQA 、 MQA(Multi-query Attention) 相比,MLA 在优化 KV Cache 和保证模型效果方面具有显著的优势。 MLA 通过低秩投影的方式替代了 GQA 中的分割和重复操作,从而大大减小了 KV Cache 的大小。与 MQA 相比,MLA 的性能和效果显著优于 MQA,甚至强于 MHA 和 GQA,真正实现了降低推理成本并保证模型性能的目标。

MLA 的核心是权重矩阵的合并。在传统的 MHA 中,注意力计算涉及到多个投影矩阵的乘法运算,而 MLA 通过合并这些投影矩阵,减少了存储和计算的开销。具体来说,MLA 将 (Q. ^(T)K 的计算结果合并为一个矩阵,并将合并后的权重应用到输入上,从而减少了存储和计算的开销。

然而,尽管 MHA 也可以进行合并,但由于其特定的计算方式,无法像 MLA 那样进行合并。 MLA 通过恒等变换技巧,将合并后的权重矩阵应用到输入上,从而实现了 KV Cache 的减少。

综上所述,MLA 通过低秩投影替代分割和重复操作,采用恒等变换技巧减少 KV Cache 的使用,从而在推理过程中降低了存储和计算的开销,实现了更高效的推理和更低的成本。


Learn more:

  1. 还在用 MHA?MLA 来了 DeepSeek-v2 的 MLA 的总结和思考 - 知乎
  2. [2405.04434] DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model
  3. DeepSeek-V2 中的 MLA 详解 - 知差 (chai)

发表评论