DeepSeek MLA 如何通过“矩阵吸收”实现 MHA 到 MQA 的无缝切换?
在当前的大模型架构之争中,显存效率(KV Cache)与模型性能(表达能力)往往是鱼与熊掌不可兼得。MHA(多头注意力)性能好但显存爆炸,MQA(多查询注意力)显存极低但可能损耗性能。
DeepSeek-V2/V3 提出的MLA (Multi-Head Latent Attention)架构,巧妙地通过低秩压缩(Low-Rank Compression)和矩阵吸收(Matrix Absorption)技术,实现了“训练时是 MHA,推理时是 MQA”的神奇效果。
本文结合 DeepSeek-V3.2 论文插图与科学空间的解析,带你拆解这一过程。
1. 核心思想:KV 的低秩压缩
MLA 的出发点是:不直接存储巨大的K KK和V VV矩阵,而是存储一个压缩后的低维潜在向量c K V c_{KV}cKV。
在标准的 MHA 中,每个头(Head)都有自己独立的K KK和V VV。而在 MLA 中,生成逻辑如下:
- 输入向量h t h_tht经过投影生成压缩向量c K V c_{KV}cKV。
- 训练时(图 a - MHA Mode):c K V c_{KV}cKV通过两个上投影矩阵W U K W^{UK}WUK和W U V W^{UV}WUV,“还原”出每个头所需的k t , i C k_{t,i}^Ckt,iC和v t , i C v_{t,i}^Cvt,iC。
k t , i C = c K V ⋅ W i U K k_{t,i}^C = c_{KV} \cdot W_{i}^{UK}kt,iC=cKV⋅WiUK
v t , i C = c K V ⋅ W i U V v_{t,i}^C = c_{KV} \cdot W_{i}^{UV}vt,iC=cKV⋅WiUV
这看起来依然是 MHA,因为每个头确实获得了解耦的 Key 和 Value。
2. 推理时的魔法:矩阵吸收 (Matrix Absorption)
MLA 最大的创新在于:在推理(Decoding)阶段,我们不需要真的把k kk和v vv还原出来存入 KV Cache。
利用矩阵乘法的结合律,我们可以将用于还原K KK和V VV的投影矩阵(W U K , W U V W^{UK}, W^{UV}WUK,WUV),分别“吸收”到Query 端和Output 端。
A. Key 的吸收(变为 MQA 形式)
计算 Attention 分数的核心公式是Q ⋅ K T Q \cdot K^TQ⋅KT。在 MLA 中,代入K KK的生成公式:
Score = q t , i C ⋅ ( k t , i C ) T = q t , i C ⋅ ( c K V ⋅ W i U K ) T \text{Score} = q_{t,i}^C \cdot (k_{t,i}^C)^T = q_{t,i}^C \cdot (c_{KV} \cdot W_{i}^{UK})^TScore=qt,iC⋅(kt,iC)T=qt,iC⋅(cKV⋅WiUK)T
利用转置性质( A B ) T = B T A T (AB)^T = B^T A^T(AB)T=BTAT:
Score = q t , i C ⋅ ( W i U K ) T ⋅ c K V T \text{Score} = q_{t,i}^C \cdot (W_{i}^{UK})^T \cdot c_{KV}^TScore=qt,iC⋅(WiUK)T⋅cKVT
这里发生了一个关键变换:我们可以结合q t , i C ⋅ ( W i U K ) T q_{t,i}^C \cdot (W_{i}^{UK})^Tqt,iC⋅(WiUK)T作为一个新的 Query。
- 对应图中 (b) 的蓝色箭头:W i U K W_i^{UK}WiUK不再用于生成K KK,而是直接作用于 Query。
- 结果:KV Cache 中只需要存储压缩后的c K V c_{KV}cKV。对于所有头来说,c K V c_{KV}cKV是共享的。这不就是MQA (Multi-Query Attention)吗?(即所有头共享一个 Key)。
B. Value 的吸收
同理,对于 Attention 的输出计算:
o t , i = AttnWeight ⋅ v t , i C = AttnWeight ⋅ ( c K V ⋅ W i U V ) o_{t,i} = \text{AttnWeight} \cdot v_{t,i}^C = \text{AttnWeight} \cdot (c_{KV} \cdot W_{i}^{UV})ot,i=AttnWeight⋅vt,iC=AttnWeight⋅(cKV⋅WiUV)
利用结合律,我们可以先计算AttnWeight ⋅ c K V \text{AttnWeight} \cdot c_{KV}AttnWeight⋅cKV,最后再乘以W i U V W_{i}^{UV}WiUV。
- 对应图中 (b) 的橙色箭头:W i U V W_i^{UV}WiUV被移到了 Attention 计算之后,甚至可以进一步融合到最终的 Output Projection (W O W_OWO) 中。
- 结果:KV Cache 中不需要存展开的V VV,只需要存c K V c_{KV}cKV。
3. RoPE 的处理(Decoupled RoPE)
细心的读者会发现图中还有一个apply RoPE的分支。
为了避免旋转位置编码(RoPE)破坏上述的矩阵吸收特性(RoPE 是位置敏感的,不能简单被线性矩阵吸收),MLA 采用了Decoupled RoPE(解耦 RoPE)策略:
q = [ q c o n t e n t , q r o p e ] ; k = [ k c o n t e n t , k r o p e ] q = [q_{content}, q_{rope}]; \quad k = [k_{content}, k_{rope}]q=[qcontent,qrope];k=[kcontent,krope]
- Content 部分 (c K V c_{KV}cKV):完全压缩,执行矩阵吸收,变成 MQA 模式。
- RoPE 部分 (k R k^RkR):单独保留,携带位置信息,随c K V c_{KV}cKV一起缓存。
4. 总结:图 (a) 到 图 (b) 的变换
回到 DeepSeek-V3.2 的 Figure 7:
- 图 (a) MHA Mode:展示了逻辑上的计算过程。c K V c_{KV}cKV分裂并通过W U K W^{UK}WUK、W U V W^{UV}WUV变成多头的k kk和v vv。这是模型训练时的视角,保证了模型拥有多头的表达能力。
- 图 (b) MQA Mode:展示了物理上的计算过程(推理时)。
- W U K W^{UK}WUK被吸收到 Query 侧(蓝色箭头)。
- W U V W^{UV}WUV被吸收到 Output 侧(橙色箭头)。
- KV Cache:只剩下灰色的c K V c_{KV}cKV和小部分的k R k^RkR。
结论:MLA 通过数学上的等价变换,在不损失 MHA 性能(因为数学上完全等价)的前提下,将推理时的显存占用降低到了MQA 的水平。这就是 DeepSeek-V3 能够支持超长上下文且推理高效的核心秘密。