DeepSeek-V2 中的 MLA 详解

DeepSeek-V2 是 DeepSeek 团队最新发布的 MoE(Mixture of Experts) 架构的 LLM(大型语言模型) 底座。该模型拥有 236B 的总参数量和 21B 的每个 token 激活参数量,支持 128K tokens 的上下文长度。 DeepSeek-V2 的一个核心创新点就是 Multi-head Latent Attention(MLA) 。

Multi-head Latent Attention(MLA) 简介

MLA 对传统 Transformer 中的多头注意力机制 (MHA) 进行了改进,主要目标是:

  1. 降低推理时 KV Cache 的存储开销;
  2. 缓解 GQA(Grouped-Query Attention) 和 MQA(Multi-Query Attention) 等方法导致的模型性能损耗。

标准的 MHA 结构

在标准的 MHA 结构中,每个 token 的 query 、 key 和 value 通过参数矩阵映射得到,并分割成多个注意力头。每个头独立计算注意力权重并得到输出,这个过程虽然能捕捉丰富的上下文信息,但在推理时需要缓存大量的 KV Cache 。

MLA 如何改进?

MLA 通过对 keys 和 values 进行低秩联合压缩来降低 KV Cache:

  1. 低秩 Key-Value 联合压缩
    [
    \mathbf{c}_t^{KV} = W^{DKV} \mathbf{h}_t
    ]
    [
    \mathbf{k}_t^C = W^{UK} \mathbf{c}_t^{KV}
    ]
    [
    \mathbf{v}_t^C = W^{UV} \mathbf{c}_t^{KV}
    ]
    其中,(\mathbf{c}_t^{KV}) 表示压缩后的隐向量,(W^{DKV}) 是降维映射矩阵,(W^{UK}) 和 (W^{UV}) 是升维映射矩阵。在推理时,只需要缓存隐向量 (\mathbf{c}_t^{KV}),显著减少了 KV Cache 的容量。
  2. Queries 的低秩压缩
    [
    \mathbf{c}_t^Q = W^{DQ} \mathbf{h}_t
    ]
    [
    \mathbf{q}_t^C = W^{UQ} \mathbf{c}_t^Q
    ]
    这样即便不能减少 KV Cache,但可以降低训练过程中的激活内存。

代码实现

以下是 MLA 在 DeepSeek-V2 中的 Python 代码实现片段:


class DeepSeekV2Attention(nn.Module):
def init(self, config: DeepSeekV2Config, layer_idx: Optional[int] = None):

self.w_dq = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias)
self.w_uq = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)
self.w_dkv = nn.Linear(self.hidden_size, self.dc, bias=config.attention_bias)
self.w_uk = nn.Linear(self.dc, self.num_heads * self.q_head_dim, bias=False)
self.w_uv = nn.Linear(self.dc, self.num_heads * self.q_head_dim, bias=False)

def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None,
            output_attentions: bool = False, use_cache: bool = False, **kwargs):
    bsz, q_len, _ = hidden_states.size()

    q = self.w_uq(self.q_a_layernorm(self.w_dq(hidden_states))).view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
    kv_seq_len = q.size(-2)
    if past_key_value is not None:
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

    compressed_kv = self.w_dkv(hidden_states)
    if past_key_value is not None:
        compressed_kv = past_key_value.update(compressed_kv)

    k = self.w_uk(compressed_kv).view(bsz, -1, self.num_heads, self.q_head_dim).transpose(1, 2)
        v = self.w_uv(compressed_kv).view(bsz, -1, self.num_heads, self.q_head_dim).transpose(1, 2)

        attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.softmax_scale
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            outputs = (attn_weights,)
        else:
            outputs = ()

        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)

        attn_output = self.out_proj(attn_output)
        outputs = (attn_output,) + outputs

        if use_cache:
            outputs = outputs + (past_key_value,)

        return outputs
```


结论
DeepSeek-V2 通过引入 Multi-head Latent Attention(MLA) 结构,成功优化了传统的多头注意力机制 (MHA),在保证模型性能的同时,显著降低了推理时 KV Cache 的存储开销。这不仅提高了模型的效率,也为未来的大模型架构设计提供了新的思路。

MLA 的实现通过对 queries 、 keys 和 values 进行低秩压缩,减少了存储需求,缓解了因 GQA 和 MQA 方法导致的性能损耗。这种创新在深度学习模型的设计中具有重要的参考价值。

如果你对于 DeepSeek-V2 的 MLA 结构有更多的兴趣,建议查看其开源代码和详细文档,以便深入理解其工作机制和实现细节。

发表评论