DeepSeek-V2 是 DeepSeek 团队最新发布的 MoE(Mixture of Experts) 架构的 LLM(大型语言模型) 底座。该模型拥有 236B 的总参数量和 21B 的每个 token 激活参数量,支持 128K tokens 的上下文长度。 DeepSeek-V2 的一个核心创新点就是 Multi-head Latent Attention(MLA) 。
Multi-head Latent Attention(MLA) 简介
MLA 对传统 Transformer 中的多头注意力机制 (MHA) 进行了改进,主要目标是:
- 降低推理时 KV Cache 的存储开销;
- 缓解 GQA(Grouped-Query Attention) 和 MQA(Multi-Query Attention) 等方法导致的模型性能损耗。
标准的 MHA 结构
在标准的 MHA 结构中,每个 token 的 query 、 key 和 value 通过参数矩阵映射得到,并分割成多个注意力头。每个头独立计算注意力权重并得到输出,这个过程虽然能捕捉丰富的上下文信息,但在推理时需要缓存大量的 KV Cache 。
MLA 如何改进?
MLA 通过对 keys 和 values 进行低秩联合压缩来降低 KV Cache:
- 低秩 Key-Value 联合压缩:
[
\mathbf{c}_t^{KV} = W^{DKV} \mathbf{h}_t
]
[
\mathbf{k}_t^C = W^{UK} \mathbf{c}_t^{KV}
]
[
\mathbf{v}_t^C = W^{UV} \mathbf{c}_t^{KV}
]
其中,(\mathbf{c}_t^{KV}) 表示压缩后的隐向量,(W^{DKV}) 是降维映射矩阵,(W^{UK}) 和 (W^{UV}) 是升维映射矩阵。在推理时,只需要缓存隐向量 (\mathbf{c}_t^{KV}),显著减少了 KV Cache 的容量。 - Queries 的低秩压缩:
[
\mathbf{c}_t^Q = W^{DQ} \mathbf{h}_t
]
[
\mathbf{q}_t^C = W^{UQ} \mathbf{c}_t^Q
]
这样即便不能减少 KV Cache,但可以降低训练过程中的激活内存。
代码实现
以下是 MLA 在 DeepSeek-V2 中的 Python 代码实现片段:
class DeepSeekV2Attention(nn.Module):
def init(self, config: DeepSeekV2Config, layer_idx: Optional[int] = None):
…
self.w_dq = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias)
self.w_uq = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)
self.w_dkv = nn.Linear(self.hidden_size, self.dc, bias=config.attention_bias)
self.w_uk = nn.Linear(self.dc, self.num_heads * self.q_head_dim, bias=False)
self.w_uv = nn.Linear(self.dc, self.num_heads * self.q_head_dim, bias=False)
def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None,
output_attentions: bool = False, use_cache: bool = False, **kwargs):
bsz, q_len, _ = hidden_states.size()
q = self.w_uq(self.q_a_layernorm(self.w_dq(hidden_states))).view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
kv_seq_len = q.size(-2)
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
compressed_kv = self.w_dkv(hidden_states)
if past_key_value is not None:
compressed_kv = past_key_value.update(compressed_kv)
k = self.w_uk(compressed_kv).view(bsz, -1, self.num_heads, self.q_head_dim).transpose(1, 2)
v = self.w_uv(compressed_kv).view(bsz, -1, self.num_heads, self.q_head_dim).transpose(1, 2)
attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.softmax_scale
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if output_attentions:
outputs = (attn_weights,)
else:
outputs = ()
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
attn_output = self.out_proj(attn_output)
outputs = (attn_output,) + outputs
if use_cache:
outputs = outputs + (past_key_value,)
return outputs
```
结论
DeepSeek-V2 通过引入 Multi-head Latent Attention(MLA) 结构,成功优化了传统的多头注意力机制 (MHA),在保证模型性能的同时,显著降低了推理时 KV Cache 的存储开销。这不仅提高了模型的效率,也为未来的大模型架构设计提供了新的思路。
MLA 的实现通过对 queries 、 keys 和 values 进行低秩压缩,减少了存储需求,缓解了因 GQA 和 MQA 方法导致的性能损耗。这种创新在深度学习模型的设计中具有重要的参考价值。
如果你对于 DeepSeek-V2 的 MLA 结构有更多的兴趣,建议查看其开源代码和详细文档,以便深入理解其工作机制和实现细节。