LLM 架构计算方法论

从 config.json 到参数量、FLOPs、KV Cache、推理显存的完整计算推导。基于 8 个开源模型的实战拆解经验。


目录

  • CH 1 预备知识:从 config.json 到矩阵乘法
  • CH 2 参数分解:这个模型有多大
  • CH 3 FLOPs 估算:推理一次花多少计算
  • CH 4 KV Cache 显存:长上下文为什么吃显存
  • CH 5 推理显存:部署需要多少卡
  • CH 6 实战:MiniMax M3 完整推演
  • 附录 A config.json 字段速查表
  • 附录 B 符号与缩写表
  • 附录 C 8 个模型计算结果速览

阅读导航

你的目标推荐阅读路径预计时间
快速了解全貌CH 1.2(FLOPs基础)→ CH 2.3(Attention参数)→ CH 4.2(KV cache公式)→ 附录 C(8模型速览)30 min
学会算参数量CH 1.1(config字段)→ CH 1.4(符号表)→ CH 2(全章,4个案例代入)→ CH 2.10(Nemotron完整推演)60 min
学会算 FLOPsCH 1.2(FLOPs公式)→ CH 3.2(Full Attn)→ 按需选读 3.3(MSA)/ 3.4(MLA)/ 3.5(Mamba-2)/ 3.6(SWA)/ 3.7(DeltaNet)→ CH 3.10(跨架构对比)45 min
学会算 KV cacheCH 1.4(符号表)→ CH 4.2(标准GQA)→ CH 4.3(MLA重点)→ CH 4.5(SWA)/ 4.6(DeltaNet)/ 4.7(Mamba-2)→ CH 4.9(对比表)40 min
独立推演一个模型CH 1(预备知识,15 min)→ CH 6(M3完整推演,对照 config.json 自己算一遍)90 min
查漏补缺附录 A(config字段→计算项映射)→ 附录 B(符号表)→ 定位到对应章节5 min

各章之间的依赖关系:CH 2 → CH 3(参数是 FLOPs 的输入,但 FLOPs 的核心公式独立)→ CH 4(FLOPs 和 KV cache 无依赖,可并行阅读)→ CH 5(依赖 CH 2 + CH 4)→ CH 6(依赖全部)。

新读者建议:从 CH 1.2(5 分钟搞懂 FLOPs 怎么数)和 CH 4.2(10 分钟搞懂 KV cache 怎么算)开始——这两节能让你最快建立「能算」的感觉。


CH 1 预备知识 & CH 2 参数分解

读者定位:有 Transformer 基础知识的工程师,目标是从 config.json 独立推导任意模型的参数量。


CH 1 | 预备知识

1.1 读 config.json

参数量不是猜出来的——config.json 是唯一真相来源。

下表列出影响计算的核心字段,每种架构类型给一个真实案例:

字段含义Nemotron 3 UltraMiniMax M3Kimi K2.5DeepSeek V4 Flash
hidden_size残差流维度 $d$8192614471684096
num_attention_headsQ 头数 $H_q$64646464
num_key_value_headsKV 头数 $H_{kv}$24641
head_dim每 head 维度 $D_h$128128(见 MLA)512
intermediate_sizeDense FFN 中间维 $d_{ff}$5120307218432
moe_intermediate_sizeMoE 专家中间维5120307220482048
n_routed_experts路由专家数 $E$512128384256
num_experts_per_tok每 token 激活专家数 $k$22486
vocab_size词表大小 $V$131072200064163840129280

MLA(Multi-head Latent Attention)特有字段

字段Kimi K2.5含义
kv_lora_rank512K 和 V 的压缩维度 $d_{kv}$
q_lora_rank1536Q 的压缩维度 $d_q$
qk_nope_head_dim128每头无位置编码维度 $D_{nope}$
qk_rope_head_dim64每头 RoPE 维度 $D_{rope}$
v_head_dim128每头 V 维度 $D_v$

注:MLA 中 $D_h = D_{nope} + D_{rope}$,对 K2.5 而言 $D_h = 128 + 64 = 192$。GQA/MHA 模型通常直接给 head_dim,不需要这几个字段。

Mamba-2 特有字段(Nemotron):

字段Nemotron含义
ssm_state_size128SSM 隐状态维度 $N$
mamba_num_heads256Mamba 头数 $H_{mamba}$
mamba_head_dim64Mamba 每头维度 $D_{mamba}$
n_groups8A 矩阵分组数(Mamba-2 的多头扩展)
conv_kernel41D 深度卷积核大小
expand2内部扩展因子($d_{inner} = 2 \times d$)

Vision 相关字段(M3):

字段M3 值含义
vision_config.hidden_size1280ViT 隐藏维度
vision_config.num_attention_heads16ViT 注意力头数
vision_config.num_hidden_layers32ViT 层数
vision_config.intermediate_size5120ViT MLP 中间维
vision_config.patch_size14Patch 大小
vision_config.image_size2016输入图像尺寸

MoE 相关补充字段

字段含义例子
moe_latent_sizeNemotron 低秩投影维度2048
moe_shared_expert_intermediate_size共享专家中间维Nemotron: 10240
dense_intermediate_sizeDense 层 FFN 中间维(M3 前 3 层)M3: 12288
shared_intermediate_size共享专家中间维(M3)M3: 3072
n_shared_experts共享专家数量通常为 1
scoring_func路由评分函数sigmoid / softmax
tie_word_embeddings输入/输出 Embedding 是否共享权重false(多数大模型不共享)

实战提示:打开 config.json 后,先把上述字段圈出来列成一个小表。后续所有计算都不需要看源码——只看这个表就能推出 95% 以上的参数量。


1.2 矩阵乘法 FLOPs 是怎么算的

建立“矩阵乘法的计算量直觉”。参数量是“存了多少数”,FLOPs 是“每次前向要算多少步”——两者是同一个硬币的两面。

1.2.1 基本定义

现代深度学习框架中,一次 Multiply-Accumulate(MAC,乘加)计为 2 FLOPs(1 次乘法 + 1 次加法)。

矩阵乘法 $C = A \cdot B$,其中 $A \in \mathbb{R}^{m \times k}$,$B \in \mathbb{R}^{k \times n}$:

$$\text{FLOPs} = 2 \cdot m \cdot n \cdot k$$

1.2.2 完整代入案例

假设我们在计算 Attention 层的 Q 投影:

$$\text{hidden\_states} \in \mathbb{R}^{1 \times 4096 \times 6144}$$$$W_Q \in \mathbb{R}^{6144 \times 6144}$$$$\text{FLOPs}_{Q} = 2 \times 4096 \times 6144 \times 6144 = 2 \times 4096 \times 37,748,736$$$$= 309{,}237{,}645{,}312 \approx 309 \text{ GFLOPs}$$

$m \times n$ 是输出矩阵的大小($4096 \times 6144$),每个输出元素需要做 $k=6144$ 次乘加。把这 2500 万个输出元素每个都做 6144 次运算,再 ×2,就是总的浮点运算次数。

1.2.3 分解技巧

一个大矩阵乘法的 FLOPs 可以按“输出形状 × 2 × 公共维度”来记:

  • hidden [B, S, d] × W [d, d_out] $\to$ FLOPs = $2 \cdot B \cdot S \cdot d \cdot d_{out}$
  • Q [B, H, S, D] × K^T [B, H, D, S] $\to$ FLOPs = $2 \cdot B \cdot H \cdot S \cdot S \cdot D$
  • attn [B, H, S, S] × V [B, H, S, D] $\to$ FLOPs = $2 \cdot B \cdot H \cdot S \cdot S \cdot D$

其中 $B$ 是 batch size,$S$ 是序列长度,$H$ 是头数,$D$ 是每头维度。


1.3 einsum 是什么

读懂 PyTorch/Flax 代码中 einsum 的维度缩并记法。绝大多数模型源码用 einsum 写注意力计算,看不懂 einsum 就看不懂代码。

1.3.1 基本语法

1
torch.einsum("bhqk,bhkv->bhqv", Q, K_T, V)

规则:

  • -> 左边是输入的维度标签,逗号分隔多个输入
  • -> 右边是输出的维度标签
  • 出现在左边但不出现在右边的标签 = 被求和缩并掉的维度
  • 字母顺序任意,但同一字母在同一个输入中只能出现一次

1.3.2 具体案例

1
2
3
4
5
6
7
8
# Q: [Batch=2, Heads=16, Seq_q=4096, Dim=64]  -> 标签 bhqk
# K: [Batch=2, Heads=16, Seq_k=4096, Dim=64]  -> 标签 bhkv  
# (注意 K 的最后一维用了与 Q 不同的标签 v,Q 最后一维是 k)

# einsum("bhqk,bhkv->bhqv", Q, K)
# 缩并维度:k(Q 的第 4 维 和 K 的第 4 维做点积)
# 保留维度:b, h, q, v
# 输出形状: [2, 16, 4096, 4096]  -> 即注意力分数矩阵

einsum 就是“对着字母做操作”——同名字母在左边多个输入中出现就做点积(乘法+求和),只在一个输入中出现的字母保留到输出。你不需要想象循环嵌套,只需要追踪每个字母的维度大小。

1.3.3 常见 Attention 计算模式

einsum 模式含义输入形状输出形状
bhqk,bhkv->bhqvQK 点积求 attention score[B,H,S_q,D] × [B,H,S_k,D][B,H,S_q,S_k]
bhqv,bhvd->bhqdAttention × V 加权[B,H,S_q,S_k] × [B,H,S_k,D][B,H,S_q,D]
bnhd,hdo->bnoOutput 投影(合并 heads)[B,H,S,D] × [H,D,d][B,S,d]
bsi,io->bso标准 Linear 层[B,S,in] × [in,out][B,S,out]

实战提示:在 PyTorch 代码中遇到 einsum 时,第一步不是理解计算逻辑,而是写出每个字母代表的维度大小——然后你就可以心算 FLOPs 了。


1.4 符号约定

统一本文后续所有公式的符号,避免混淆。RoPE 的 head_dim 和 V 的 head_dim 在 MLA 中是两个不同的值,不约定清楚会算错。

符号含义NemotronM3K2.5
$B$Batch size
$S_q$ / $S_k$ / $S_v$Query / Key / Value 序列长度(prefill 时 $S_q = S_k$)
$L$层数(不含 MTP/Vision)606061
$d$hidden_size,残差流维度819261447168
$V$vocab_size,词表大小131072200064163840
$H_q$ / $H_{kv}$Q 头数 / KV 头数64 / 264 / 464 / 64
$D_h$head_dim,每头维度(GQA/MHA 中)128128
$D_{nope}$ / $D_{rope}$MLA 中无位置/有位置编码维度128/64
$D_v$MLA 中每头 V 维度128
$d_{kv}$ / $d_q$MLA 压缩维度(kv_lora_rank / q_lora_rank512/1536
$d_{ff}$FFN 中间维度51203072 (MoE) / 12288 (Dense)2048 (MoE) / 18432 (Dense)
$d_{moe_ff}$moe_intermediate_size,MoE 专家中间维512030722048
$E$n_routed_experts,路由专家总数512128384
$k$num_experts_per_tok,每 token 激活专家数2248
$H_{mamba}$ / $D_{mamba}$Mamba 头数/每头维度256/64
$N$ssm_state_size,SSM 状态维度128
$T$序列长度(tokens)
bytes每个参数的字节数(BF16=2, FP8=1, FP4=0.5)

关键区分

  • $D_h$(GQA/MHA):一个值,Q、K、V 的 head_dim 相同
  • $(D_{nope}, D_{rope}, D_v)$(MLA):三个独立值,Q/K 的维度是 $D_{nope}+D_{rope}$,V 的维度是 $D_v$

1.5 Bytes 换算

参数个数 $\to$ 显存占用的转换。算完参数量不乘 bytes 等于白算——内存是按字节分配的,不是按“个数”。

精度字节/参数典型应用场景
FP324训练主权重(full precision)
BF162推理权重、训练前向(主流默认)
FP162部分训练框架
FP8 (E4M3)1推理量化、部分训练(如 DeepSeek V4 的 quantization_config
INT81推理量化
INT4 / NVFP40.5极端推理量化(如 Nemotron 的 NVFP4 训练)
FP4 (E2M1)0.5权重级极限压缩

换算公式

$$\text{Memory (GiB)} = \frac{\text{Params} \times \text{bytes}}{1024^3} = \frac{\text{Params} \times \text{bytes}}{1{,}073{,}741{,}824}$$

本文使用 GiB(2³⁰ bytes)而非 GiB(10⁹ bytes)用于显存计算,因为 GPU 显存以 2 的幂次分配。1 GiB ≈ 1.074 GB。

实战案例:Nemotron 3 Ultra 的 550B 参数以 BF16 存储:

$$550 \times 10^9 \times 2 \text{ bytes} = 1.1 \times 10^{12} \text{ bytes} \approx 1024 \text{ GiB} \approx 1 \text{ TiB}$$

如果换成 NVFP4 推理(仅权重部分):

$$550 \times 10^9 \times 0.5 \text{ bytes} = 275 \text{ GiB}$$

这一节内容少,但每次计算都要用到——建议手写贴在显示器旁边。


CH 2 | 参数分解

2.1 通用原理

建立“参数就是矩阵元素数”的底层逻辑。所有花哨的架构(GQA、MLA、MoE、Mamba)最终都可以归结为“有多少个权重矩阵,每个矩阵的形状是什么”。

核心公式

$$\text{Params} = \sum_{W \in \text{所有权重矩阵}} \text{size}(W)$$

其中 $\text{size}(W) = \text{in_features} \times \text{out_features}$(不含 bias,大模型中 bias 通常为 False 或可忽略)。

一级近似

一个 Decoder-only Transformer 的总参数主要由以下模块构成:

$$\text{Params}_{\text{total}} = \text{Params}_{\text{embed}} + L \times \text{Params}_{\text{attn}} + L_{\text{dense}} \times \text{Params}_{\text{ffn\_dense}} + L_{\text{moe}} \times \text{Params}_{\text{moe}} + \text{Params}_{\text{norm}} + \text{Params}_{\text{head}} + \text{Params}_{\text{other}}$$

参数量就是“把所有权重矩阵的元素数加起来”。Embedding 是一个大矩阵,每层有一个 Attention(QKV+O)和一个 FFN(gate+up+down 或 up+down),MoE 把 FFN 复制了 $E$ 份,Mamba 把 Attention 换成了自己的 in/out/ssm 参数。


2.2 Embedding 层

计算输入/输出 Embedding 的参数量。对 100K+ 词表的模型,Embedding 就占了 ~1B 参数——不是小数目。

公式

$$\text{Params}_{\text{embed\_in}} = V \times d$$$$\text{Params}_{\text{embed\_out}} = \begin{cases} V \times d & \text{若 } \texttt{tie\_word\_embeddings} = \texttt{false} \\ 0 & \text{若 } \texttt{tie\_word\_embeddings} = \texttt{true} \text{(共享输入权重)} \end{cases}$$

案例

Nemotron 3 Ultra:$V = 131072$(vocab_size),$d = 8192$(hidden_size),tie_word_embeddings = false

$$\text{Params}_{\text{embed\_in}} = 131{,}072 \times 8192 = 1{,}073{,}741{,}824 \approx 1.07\text{B}$$$$\text{Params}_{\text{embed\_out}} = 131{,}072 \times 8192 = 1.07\text{B}$$$$\text{Params}_{\text{embed\_total}} \approx 2.15\text{B}$$

MiniMax M3:$V = 200{,}064$,$d = 6144$,tie_word_embeddings = false

$$\text{Params}_{\text{embed\_in}} = 200{,}064 \times 6144 = 1{,}229{,}193{,}216 \approx 1.23\text{B}$$

Embedding 就是一个查表操作——$V$ 行,每行是一个 $d$ 维向量。输入和输出通常各有一个独立的表(因为 tie_word_embeddings=false 在大模型中很常见),所以 Embedding 部分大致是 $2Vd$ 个参数。131K 词表 × 8K 维度 ≈ 1B 参数一块,两块就是 2B。


2.3 Attention 参数

计算 Q、K、V、O 四个投影矩阵的参数。Attention 的计算量由序列长度主导,但参数量只由维度决定——理解这一点才能区分“参数”和“FLOPs”两个概念。

2.3.1 标准 MHA(Multi-Head Attention)

每个头独立的 Q、K、V,无 GQA 压缩。

$$\text{Params}_{Q} = d \times H_q \times D_h$$$$\text{Params}_{K} = d \times H_{kv} \times D_h$$$$\text{Params}_{V} = d \times H_{kv} \times D_h$$$$\text{Params}_{O} = H_q \times D_h \times d$$$$\text{Params}_{\text{MHA}} = d \times H_q \times D_h + 2 \times d \times H_{kv} \times D_h + H_q \times D_h \times d$$

当 $H_{kv} = H_q$ 时(纯 MHA,无 GQA):

$$\text{Params}_{\text{MHA}} = 4 \times d \times H_q \times D_h = 4 \times d^2 \quad (\text{若 } H_q \times D_h = d)$$

Kimi K2.5 的 Layer 0(全 MHA,$d=7168$,$H_q=H_{kv}=64$,$D_h = D_{nope} + D_{rope} = 192$):

$$\text{Params}_{Q} = 7168 \times 64 \times 192 = 88{,}080{,}384 \approx 88.1\text{M}$$$$\text{Params}_{K} = 7168 \times 64 \times 192 = 88.1\text{M}$$$$\text{Params}_{V} = 7168 \times 64 \times 128 = 58{,}720{,}256 \approx 58.7\text{M}$$$$\text{Params}_{O} = 64 \times 128 \times 7168 = 58.7\text{M}$$$$\text{Params}_{\text{MHA, per layer}} \approx 293.6\text{M}$$

注意 V 的头维度是 128(v_head_dim),不是 192——这是 MLA 的设计,即使在不压缩的 Layer 0 也遵循同样的维度约定。

MHA 就是四个大矩阵——Q 把 d 维投影到 d 维($H \times D = d$),K 也一样,V 也一样,O 把 d 维映射回 d 维。4 × d² 就是每层 Attention 的“起步价”。

2.3.2 GQA(Grouped Query Attention)

GQA 的核心:Q 头数不变,K 和 V 头数减少。K、V 矩阵“变窄”了

$$\text{Params}_{\text{GQA}} = d \times H_q \times D_h + 2 \times d \times H_{kv} \times D_h + H_q \times D_h \times d$$

Nemotron 3 Ultra(GQA 32:1,$d=8192$,$H_q=64$,$H_{kv}=2$,$D_h=128$):

$$\text{Params}_{Q} = 8192 \times 64 \times 128 = 67{,}108{,}864 \approx 67.1\text{M}$$$$\text{Params}_{K} = 8192 \times 2 \times 128 = 2{,}097{,}152 \approx 2.1\text{M}$$$$\text{Params}_{V} = 8192 \times 2 \times 128 = 2.1\text{M}$$$$\text{Params}_{O} = 64 \times 128 \times 8192 = 67.1\text{M}$$$$\text{Params}_{\text{GQA, per layer}} \approx 138.4\text{M}$$

对比全 MHA($H_{kv}=64$)的 $4 \times 8192^2 = 268.4\text{M}$,GQA 32:1 将 Attention 参数量压到了 48%

MiniMax M3(GQA 16:1,$d=6144$,$H_q=64$,$H_{kv}=4$,$D_h=128$):

$$\text{Params}_{Q} = 6144 \times 64 \times 128 = 50{,}331{,}648 \approx 50.3\text{M}$$$$\text{Params}_{K} = 6144 \times 4 \times 128 = 3{,}145{,}728 \approx 3.1\text{M}$$$$\text{Params}_{V} = 6144 \times 4 \times 128 = 3.1\text{M}$$$$\text{Params}_{O} = 64 \times 128 \times 6144 = 50.3\text{M}$$$$\text{Params}_{\text{GQA, per layer}} \approx 107.0\text{M}$$

GQA 公式速记:Q 矩阵 $d \times d$(因为 $H_q \times D_h = d$),K 矩阵 $d \times (H_{kv} \times D_h)$——“变窄”的矩阵,V 同理,O 矩阵 $d \times d$。GQA 比 MHA 省的就是 $K$、$V$ 投影省出来的 $2 \times d \times (H_q - H_{kv}) \times D_h$ 个参数。

⚠️ 最容易犯的错误:GQA 中 $H_{kv} \times D_h \neq d$(因为 $H_{kv} < H_q$,而 $H_q \times D_h = d$)。K/V 投影的输出维度不是 hidden_size,而是 num_kv_heads × head_dim。很多人直接写 d × d 给 K/V——这是 MHA 才对的。你可以在心里验证:Nemotron GQA 32:1 → K 投影 = $8192 \times (2 \times 128) = 8192 \times 256 = 2.1\text{M}$,远小于 Q 投影的 $8192^2 = 67.1\text{M}$。

MHA 里 K 和 V 也是 $d \times d$ 的方阵,GQA 把它们变窄了——因为 KV 头数只有 Q 头数的几十分之一,所以 KV 投影矩阵的列数就是 Q 投影的几分之一。O 投影不受影响,因为输出维度($H_q \times D_h = d$)不变。

2.3.3 MLA(Multi-head Latent Attention)

这节是最复杂的部分。MLA 的核心思想:不在高维空间存 K 和 V,而是先压缩到一个低维“潜空间”,再从潜空间解压恢复。这就像一个 zip 压缩——存储/传输时用压缩格式,使用时解压。

MLA 将 K 分解为两部分:

  • nope 分量(No Position Encoding):128 维,可压缩(通过潜空间)
  • rope 分量(RoPE Position Encoding):64 维,不可压缩(RoPE 必须在全维度上旋转,不能压缩后旋转)
矩阵清单

以 Kimi K2.5 为例:$d=7168$,$d_{kv}=512$,$d_q=1536$,$H=64$,$D_{nope}=128$,$D_{rope}=64$,$D_v=128$。

矩阵形状含义
$W_{kv_a}$$d \times (d_{kv} + D_{rope})$K 压缩:hidden $\to$ 压缩 KV + RoPE 分量
$W_{kv_b}$$d_{kv} \times H \times (D_{nope} + D_v)$K/V 解压:压缩空间 $\to$ 所有 head 的 nope K + V
$W_{q_a}$$d \times d_q$Q 压缩:hidden $\to$ 压缩 Q
$W_{q_b}$$d_q \times H \times D_{nope}$Q 解压:压缩 Q $\to$ 所有 head 的 nope Q
$W_{q_rope}$$d \times H \times D_{rope}$Q RoPE 分量:hidden $\to$ 所有 head 的 rope Q(不压缩)
$W_o$$H \times D_v \times d$输出投影
逐项代入计算

(1) KV 压缩投影 $W_{kv_a}$

$$\text{Params}_{kv\_a} = d \times (d_{kv} + D_{rope}) = 7168 \times (512 + 64) = 7168 \times 576 = 4{,}128{,}768 \approx 4.13\text{M}$$

这个投影将 hidden 映射为 512 维的压缩潜向量 + 64 维的 RoPE 分量。后 64 维不参与压缩,直接作为 K 的 rope 部分使用。

(2) K/V 解压投影 $W_{kv_b}$

$$\text{Params}_{kv\_b} = d_{kv} \times H \times (D_{nope} + D_v) = 512 \times 64 \times (128 + 128) = 512 \times 64 \times 256 = 8{,}388{,}608 \approx 8.39\text{M}$$

从 512 维潜空间“解压”出 64 个 head、每个 head 的 128 维 nope K 和 128 维 V。

(3) Q 压缩投影 $W_{q_a}$

$$\text{Params}_{q\_a} = d \times d_q = 7168 \times 1536 = 11{,}010{,}048 \approx 11.01\text{M}$$

(4) Q nope 解压投影 $W_{q_b}$

$$\text{Params}_{q\_b} = d_q \times H \times D_{nope} = 1536 \times 64 \times 128 = 12{,}582{,}912 \approx 12.58\text{M}$$

(5) Q RoPE 直投投影 $W_{q_rope}$

$$\text{Params}_{q\_rope} = d \times H \times D_{rope} = 7168 \times 64 \times 64 = 29{,}360{,}128 \approx 29.36\text{M}$$

为什么 Q 的 rope 部分不压缩?因为 RoPE 是按维度旋转的——如果先压缩再解压,旋转操作会被破坏。所以 rope 部分直接从 hidden 维度投影,不经过压缩/解压。

(6) 输出投影 $W_o$

$$\text{Params}_{o} = H \times D_v \times d = 64 \times 128 \times 7168 = 58{,}720{,}256 \approx 58.72\text{M}$$

(7) 单层 MLA 总计

$$\text{Params}_{\text{MLA, per layer}} \approx 4.13 + 8.39 + 11.01 + 12.58 + 29.36 + 58.72 = 124.19\text{M}$$
MLA vs MHA vs GQA 对比

假设同样 $d=7168$,$H_q=64$:

架构KV 头数Attention 参数/层相对 MHA
MHA64~293.6M100%
GQA 8:18~165.2M56%
GQA 16:14~143.7M49%
MLA (K2.5)64 (压缩后)~124.2M42%

MLA 将 Attention 参数压缩到了全 MHA 的 42%,同时保持了 64 个独立 KV 头的能力(因为解压发生在注意力计算前)。

MLA 就像“快递打包”——把 64 个 KV 头的内容先折成一个小包裹(512 维潜向量)运输(存储/KV cache),到了目的地(注意力计算前)再拆开还原。包裹小所以运费低(KV cache 小),但内容还原后跟原来差不多(注意力质量高)。额外代价是打包($W_{kv_a}$)和拆包($W_{kv_b}$)的少量参数。


2.4 FFN 参数:SwiGLU vs ReLU$^2$

区分两种主流 FFN 的门控机制,正确计算参数量。SwiGLU 比 ReLU$^2$ 多 50% 参数——不知道这个区别会把 MoE 参数量算错三分之一。

2.4.1 ReLU$^2$(Nemotron 风格的“无门控 FFN”)

只有 up 和 down 两个矩阵:

$$\text{FFN}(\mathbf{x}) = W_{down} \cdot \text{ReLU}(\mathbf{x} \cdot W_{up})^2$$$$\text{Params}_{\text{ReLU}^2} = 2 \times d \times d_{ff}$$

Nemotron 共享专家($d=8192$,$d_{ff}=10240$):

$$\text{Params} = 2 \times 8192 \times 10240 = 167{,}772{,}160 \approx 167.8\text{M}$$

2.4.2 SwiGLU(标准门控 FFN)

有三个矩阵:gate、up、down。

$$\text{FFN}(\mathbf{x}) = W_{down} \cdot (\text{SiLU}(\mathbf{x} \cdot W_{gate}) \odot \mathbf{x} \cdot W_{up})$$$$\text{Params}_{\text{SwiGLU}} = 3 \times d \times d_{ff}$$

K2.5 路由专家($d=7168$,$d_{ff}=2048$):

$$\text{Params} = 3 \times 7168 \times 2048 = 44{,}040{,}192 \approx 44.0\text{M}$$

同维度下 ReLU$^2$ 只需 $2 \times 7168 \times 2048 \approx 29.4\text{M}$。

2.4.3 对比表

激活函数矩阵数公式相对参数
ReLU$^2$2$2 \times d \times d_{ff}$100%
SwiGLU3$3 \times d \times d_{ff}$150%
Non-gated SwiGLU (M3)2 (合并)$d \times 2d_{ff} + d_{ff} \times d$100% (等价于 $3 \times d \times d_{ff}$)

注意:M3 的 “non-gated SwiGLU” 将 gate 和 up 合并为 gate_up_proj(d → 2×d_ff) 一个矩阵,参数量 $d \times 2d_{ff} = 2 \times d \times d_{ff}$,与分离的 gate + up 的总参数相同($d \times d_{ff} + d \times d_{ff} = 2 \times d \times d_{ff}$)。区别是计算路径而非参数数量。

ReLU$^2$ 是"一个门+一条路",SwiGLU 是"两个独立门汇合到一条路"。多一个门就多 $d \times d_{ff}$ 个参数。检查 config.jsonhidden_act 字段——如果是 siluswigluoai,大概率是 SwiGLU(3 个矩阵);如果是 relu2,就是 ReLU$^2$(2 个矩阵)。


2.5 MoE 参数

计算 MoE 层的完整参数量——路由器 + 所有专家 + 共享专家。MoE 占模型总参数的 90%+,算错一个专家的维度会导致总参估算差出几十 B。

2.5.1 路由器

对于最简单的 sigmoid 路由(M3、Nemotron):

$$\text{Params}_{\text{router}} = d \times E$$

M3($d=6144$,$E=128$):$\text{Params}_{\text{router}} = 6144 \times 128 = 786{,}432 \approx 0.79\text{M}$

Nemotron($d=8192$,$E=512$):$\text{Params}_{\text{router}} = 8192 \times 512 = 4{,}194{,}304 \approx 4.2\text{M}$

路由器还可以包含 e_score_correction_bias($E$ 个标量,可忽略)。更复杂的路由(如 DeepSeek V4 的 hash routing)参数更大,但原理相同——最终是一个 $d \times \text{num_experts}$ 的矩阵。

2.5.2 总 MoE 参数量

$$\text{Params}_{\text{MoE, per layer}} = \underbrace{d \times E}_{\text{router}} + \underbrace{E \times \text{Params}_{\text{expert}}}_{\text{所有路由专家}} + \underbrace{\text{Params}_{\text{shared}}}_{\text{共享专家}}$$

2.5.3 完整案例:MiniMax M3

M3 有 57 个 MoE 层(第 3-59 层),每层 128 个路由专家 + 1 个共享专家。

路由专家(SwiGLU,$d_{ff}=3072$):

$$\text{Params}_{\text{expert}} = 3 \times 6144 \times 3072 = 56{,}623{,}104 \approx 56.62\text{M}$$

实际上 M3 用 non-gated SwiGLU(gate/up 合并):$6144 \times (2 \times 3072) + 3072 \times 6144 = 56.62\text{M}$,结果相同。

每层所有路由专家

$$\text{Params}_{\text{all\_experts\_per\_layer}} = 128 \times 56.62\text{M} = 7{,}247{,}757{,}312 \approx 7.25\text{B}$$

共享专家(每层 1 个):

$$\text{Params}_{\text{shared}} = 56.62\text{M}$$

每层 MoE 总计

$$\text{Params}_{\text{MoE, per layer}} = 0.79\text{M} + 7.25\text{B} + 56.62\text{M} = 7.31\text{B}$$

57 层 MoE 总计

$$\text{Params}_{\text{MoE, 57 layers}} = 57 \times 7.31\text{B} \approx 416.4\text{B}$$

这占了 M3 总参数(~428B)的 97%

MoE 的本质是“把 FFN 复制 E 份”。每份是一个完整的 FFN,参数量 = SwiGLU 的 $3 \times d \times d_{ff}$(或 ReLU$^2$ 的 $2 \times d \times d_{ff}$)。128 份 × 56M/份 × 57 层 ≈ 400B。路由器本身才 0.79M/层——跟专家的参数量比相当于“一根羽毛跟一头大象”。

2.5.4 Nemotron 的 LatentMoE(低秩专家)

Nemotron 的 MoE 有个特殊设计:专家在低秩空间 $d_{latent}=2048$ 中计算,而非全维度 8192。

每层 MoE 结构

  • 路由器:$8192 \times 512 = 4.2\text{M}$
  • 低秩投影入:$8192 \times 2048 = 16.8\text{M}$
  • 低秩投影出:$2048 \times 8192 = 16.8\text{M}$
  • 路由专家(ReLU$^2$,在 latent 空间):$2 \times 2048 \times 5120 = 20.97\text{M}$/专家
  • 512 专家:$512 \times 20.97\text{M} = 10.74\text{B}$
  • 共享专家(ReLU$^2$,在 full 空间):$2 \times 8192 \times 10240 = 167.8\text{M}$

每层 MoE 总计:$\approx 10.94\text{B}$

48 层合计:$\approx 525.4\text{B}$

Nemotron 的 MoE 是“先降维再升维”的低秩设计——hidden 从 8192 压到 2048,在 2048 维空间里做 512 个专家计算,再升回 8192。这比直接在 8192 维做专家(每个专家 $2 \times 8192 \times 5120 = 83.9\text{M}$)节省了 75% 的参数量——代价是低秩压缩的信息损失。


2.6 Mamba-2 参数(Nemotron)

计算 Mamba-2 SSD 层的参数量。Nemotron 有 48 个 Mamba 层——它不是 Attention,不能套用 QKV 公式。

2.6.1 维度推导

config.json 直接读到:

  • $d = 8192$
  • $H_{mamba} = 256$(mamba_num_heads
  • $D_{mamba} = 64$(mamba_head_dim
  • $N = 128$(ssm_state_size
  • $n_{groups} = 8$(n_groups
  • kernel = 4(conv_kernel
  • expand = 2(expand

推导内部维度:

  • $d_{inner} = \text{expand} \times d = 2 \times 8192 = 16384$
  • 验证:$H_{mamba} \times D_{mamba} = 256 \times 64 = 16384$ ← 自洽
  • $d_{conv} = d_{inner} + 2 \times n_{groups} \times N = 16384 + 2 \times 8 \times 128 = 16384 + 2048 = 18432$

2.6.2 逐项参数

(1) in_proj(输入投影,一投多产)

Mamba 的 in_proj 一次性投影出所有需要的分量:$x$、$z$、$B$、$C$、$\Delta$ 的参数。

$$\text{Params}_{\text{in\_proj}} = d \times (d_{inner} + d_{conv} + H_{mamba}) = 8192 \times (16384 + 18432 + 256)$$$$= 8192 \times 35072 = 287{,}309{,}824 \approx 287.3\text{M}$$

分解:35072 = 16384($x$ 和 $z$ 各 $d_{inner}$,共 $2 \times d_{inner}$)+ 18432($B$ 和 $C$ 的 $d_{conv}$)+ 256($\Delta$ 的 $H_{mamba}$)。

等等,让我重新梳理。$2 \times d_{inner} = 32768$,$2 \times n_{groups} \times N = 2048$(B 和 C),$H_{mamba} = 256$(Δ)。合计 $32768 + 2048 + 256 = 35072$。自洽。

(2) conv1d(深度卷积)

$$\text{Params}_{\text{conv1d}} = d_{conv} \times \text{kernel} + d_{conv} = 18432 \times 4 + 18432 = 92{,}160 \approx 0.09\text{M}$$

深度卷积(每个通道独立卷积核),参数极少。

(3) A_logDdt_bias(SSM 内部标量)

$$\text{Params}_{A\_log} = H_{mamba} = 256$$$$\text{Params}_{D} = H_{mamba} = 256$$$$\text{Params}_{dt\_bias} = H_{mamba} = 256$$

三个加起来不到 1000 个参数——完全可以忽略。

(4) out_proj(输出投影)

$$\text{Params}_{\text{out\_proj}} = d_{inner} \times d = 16384 \times 8192 = 134{,}217{,}728 \approx 134.2\text{M}$$

2.6.3 单层 Mamba-2 汇总

组件参数量
in_proj287.3M
conv1d0.09M
A_log + D + dt_bias~0.001M
out_proj134.2M
单层合计~421.6M

48 层 Mamba-2 合计:$48 \times 421.6\text{M} \approx 20.2\text{B}$

Mamba 的 in_proj 是“一拖多”——一个矩阵输出 5 件事(x, z, B, C, Δ),所以它特别胖(8192 × 35072 = 287M)。out_proj 再把它收回来。其余部分(卷积、状态标量)几乎不占参数。对比一下:Nemotron 的 Attention 层(138M)比 Mamba 层(422M)便宜 3 倍


2.7 Vision Encoder 参数(M3 / K2.5)

计算 ViT 编码器和投影器的参数量。VL 模型的视觉编码器通常有 0.6-2B 参数,在算总参和激活参时都要考虑。

2.7.1 MiniMax M3 视觉编码器

ViT 32 层(vision_config),$d_{vit}=1280$,$H_{vit}=16$,$D_{vit}=1280/16=80$,$d_{ff}^{vit}=5120$。

每层 Attention(标准 MHA,无 GQA):

$$\text{Params}_{\text{ViT attn}} = 4 \times (d_{vit} \times H_{vit} \times D_{vit}) = 4 \times (1280 \times 16 \times 80)$$$$= 4 \times 1{,}638{,}400 = 6{,}553{,}600 \approx 6.55\text{M}$$

每层 MLP(GELU,2 个矩阵):

$$\text{Params}_{\text{ViT mlp}} = 2 \times d_{vit} \times d_{ff}^{vit} = 2 \times 1280 \times 5120 = 13{,}107{,}200 \approx 13.11\text{M}$$

每层合计:19.66M。32 层:$\approx 629\text{M} \approx 0.63\text{B}$。
加上 patch embedding(Conv3d)+ Pre-LN + 3D RoPE:$\approx 0.65\text{B}$。

投影器(双阶段 MLP,$d_{vit} \to d \to d$ + spatial merge):

$$\text{Stage 1}: 1280 \times 6144 + 6144 \times 6144 \approx 7.86\text{M} + 37.75\text{M} = 45.6\text{M}$$$$\text{Stage 2}: (4 \times 6144) \times 6144 + 6144 \times 6144 \approx 150.99\text{M} + 37.75\text{M} = 188.7\text{M}$$$$\text{Params}_{\text{projector}} \approx 0.23\text{B}$$

视觉总计:$\approx 0.88\text{B}$。

2.7.2 Kimi K2.5 视觉编码器

ViT 27 层(vision_config.vt_num_hidden_layers),$d_{vit}=1152$,$H_{vit}=16$,$d_{ff}^{vit}=4304$。

每层 Attention

$$\text{Params}_{\text{ViT attn}} = 4 \times (1152 \times 16 \times 72) = 4 \times 1{,}327{,}104 \approx 5.31\text{M}$$

($D_{vit} = 1152/16 = 72$,但 config 中 mm_hidden_size=1152vt_hidden_size=1152,需验证 head_dim = 1152/16 = 72)

每层 MLP

$$\text{Params}_{\text{ViT mlp}} = 2 \times 1152 \times 4304 \approx 9.92\text{M}$$

每层合计:~15.23M。27 层:$\approx 0.41\text{B}$。加 PatchMerger 和投影器共约 2B。

ViT 就是一个小号 Transformer。算它跟算文本骨干的方法完全一样——QKV+O + MLP up/down,只是维度小得多(1152/1280 vs 6144/7168)。但 27-32 层加起来也有 ~0.6-2B 参数,不容忽略。


2.8 MTP Predictor 参数

计算 Multi-Token Prediction 模块的参数。MTP 模块不算在激活参数里(推理时是独立的投机解码模块),但算总参时不能漏。

MTP 模块的结构与主干的单个 layer 相同:1 个 Attention + 1 个 MoE(或 Mamba)。

Nemotron 3 Ultra(1 个 MTP 层,类型 ["attention", "moe"]):

$$\text{Params}_{\text{MTP}} = \text{Params}_{\text{attn}} + \text{Params}_{\text{MoE, 1 layer}} \approx 138.4\text{M} + 10.94\text{B} \approx 11.1\text{B}$$

MiniMax M3(7 个 MTP 模块,num_mtp_modules=7,每个含 1 layer):

M3 的 MTP 模块共享 Embedding 和 LM Head,每个模块的结构和主干层类似但维度可能不同。精确参数量需从源码确认,当前一级近似:

$$\text{Params}_{\text{MTP, per module}} \approx \text{Params}_{\text{attn}} + \text{Params}_{\text{MoE, 1 layer}} \approx 111\text{M} + 7.31\text{B} \approx 7.42\text{B}$$

7 个模块:$\approx 52\text{B}$。但官方标称 MTP 不显著增加推理显存(因为 MTP 权重可能与主干有共享或使用更小的维度),实际数值以官方技术报告为准。

设计意图待确认:M3 的 _keys_to_ignore_on_load_unexpected: [r"mtp.*"] 表明 MTP 权重在独立命名空间下。参数可能比主干层小(使用更小的 intermediate 维度),或通过参数共享减少总量。

MTP 就是“多长了几层”——如果是 1 个 MTP 模块,等于多 1 个 Attention + 1 个 MoE 层。如果 7 个 MTP 模块就是多 7 层。区别在于 MTP 只用于预测 future tokens,不是 backbone 的一部分。


2.9 激活参 vs 总参

区分“模型存了多少参数”和“每次前向要用多少参数”。推理显存 = 激活参数 × bytes/param + KV cache + 其他。不懂激活参就算不了推理成本。

核心概念

  • 总参数量(Total Params):所有权重矩阵的元素总数。模型文件的大小。
  • 激活参数量(Active Params):单次前向传播实际参与计算的参数。MoE 中只激活 top-k 专家。
$$\text{Params}_{\text{active}} = \text{Params}_{\text{non-MoE}} + \text{Params}_{\text{router}} + k \times \text{Params}_{\text{expert}} + \text{Params}_{\text{shared}}$$$$\text{激活率} = \frac{\text{Params}_{\text{active}}}{\text{Params}_{\text{total}}} \times 100\%$$

案例 1:Nemotron 3 Ultra

组件总参 (B)激活参 (B)说明
Embedding + LM Head2.152.15全激活
48 Mamba-2 层20.2420.24全激活(无 MoE)
12 Attention 层1.661.66全激活
48 MoE 层 (512E, top-22)525.432.04只激活 22/512
MTP Predictor11.1不计入独立模块,推理时按需使用
Norm 等~0.001~0.001
总计~560B~56B
官方标称550B55B偏差 ~2%
$$\text{激活率} = \frac{55}{550} = 10\%$$

Nemotron 虽然存了 550B 参数,但每次只用其中 55B——因为 48 个 MoE 层每层只在 512 个专家中激活 22 个(4.3%)。剩下 95.7% 的专家参数“休眠”。这就是 MoE 的核心价值:总容量大,推理成本低。

案例 2:MiniMax M3

$$\text{Params}_{\text{active}} \approx 1.23\text{B} + 6.64\text{B} + 0.68\text{B} + 12.91\text{B} + 3.23\text{B} + 1.23\text{B} \approx 25.9\text{B}$$

(Embedding + Attention + Dense FFN + 4/128 专家激活 + 共享专家 + LM Head)

$$\text{激活率} = \frac{25.9}{428} \approx 6.0\%$$

加上 Vision 编码器(0.88B,图像输入时激活)约为 26.8B。官方标称 ~23B。

各模型激活率对比

模型总参激活参激活率每 token 专家激活比例
Nemotron 3 Ultra550B~55B10.0%22/512 = 4.3%
MiniMax M3~428B~23-26B5.4-6.0%4/128 = 3.1%
Kimi K2.5~1T~32B3.2%8/385 ≈ 2.1%

激活率越高,同等总参下推理越贵。Nemotron 的 10% 激活率看起来高,但因为它有 48 个 Mamba 层(无稀疏化),这些层每个 token 都要全部跑一遍。M3 和 K2.5 的激活率更低是因为它们几乎所有层的 FFN 都是 MoE。


2.10 完整案例:Nemotron 3 Ultra 参数分解

config.json 出发,一步步列出每类模块的参数,求和验证 ≈ 550B。这是本章所有知识的综合运用——读完你应该能对任何模型做同样的事。

Step 0:读取 config.json

关键字段值(见 1.1 节表)。

Step 1:Embedding

$$131072 \times 8192 = 1.07\text{B (输入)} + 1.07\text{B (输出)} = \mathbf{2.15\text{B}}$$

Step 2:48 个 Mamba-2 层

$$48 \times (287.3\text{M} + 0.09\text{M} + 134.2\text{M}) = 48 \times 421.6\text{M} = \mathbf{20.24\text{B}}$$

Step 3:12 个 Attention 层(GQA 32:1)

$$12 \times (67.1\text{M} + 2.1\text{M} + 2.1\text{M} + 67.1\text{M}) = 12 \times 138.4\text{M} = \mathbf{1.66\text{B}}$$

Step 4:48 个 MoE 层

每层:

  • Router: $8192 \times 512 = 4.2\text{M}$
  • 低秩投影入+出: $8192 \times 2048 + 2048 \times 8192 = 33.6\text{M}$
  • 512 专家 (ReLU$^2$, latent 空间): $512 \times 2 \times 2048 \times 5120 = 10,737.4\text{M} = 10.74\text{B}$
  • 共享专家 (ReLU$^2$, full 空间): $2 \times 8192 \times 10240 = 167.8\text{M}$

单层:$4.2\text{M} + 33.6\text{M} + 10,737.4\text{M} + 167.8\text{M} = 10,943\text{M} \approx 10.94\text{B}$

48 层:$48 \times 10.94\text{B} = \mathbf{525.4\text{B}}$

Step 5:MTP Predictor(1 attention + 1 moe)

$$\mathbf{11.1\text{B}}$$

Step 6:求和

模块参数 (B)占比
Embedding + LM Head2.150.4%
48 Mamba-2 层20.243.6%
12 Attention 层1.660.3%
48 LatentMoE 层525.494.1%
MTP Predictor11.12.0%
Norm 等~0.001~0%
直接求和~560.5
官方标称550

偏差 ~1.9%,可能来源:MTP 权重有部分与主干共享;部分维度在实现中与 config 有细微差异;NVFP4 训练下的有效参数量口径不同。

Step 7:激活参验证

$$\text{Active} = 2.15 + 20.24 + 1.66 + 48 \times (4.2\text{M} + 33.6\text{M} + 22 \times 21\text{M} + 167.8\text{M}) \div 10^3$$$$= 2.15 + 20.24 + 1.66 + 48 \times 0.6675\text{B}$$$$= 2.15 + 20.24 + 1.66 + 32.04 = \mathbf{56.1\text{B}}$$

与官方 55B 偏差 ~2%。扣除 MTP(11.1B)后 backbone 激活 ≈ 56B,与标称一致。

自查清单(算完参数量后对照):

  • Embedding = vocab_size × hidden_size?weight tying 只乘一次?
  • GQA 的 K/V 矩阵是 d × (H_kv × D_h) 不是 d × d
  • SwiGLU 是 3 个矩阵(gate/up/down),ReLU² 是 2 个?
  • MoE = Router + N_experts × expert + shared_expert(别忘了 Router)?
  • 各项之和 ≈ 官方标称值(允许 1-2% 偏差)?
  • 激活参 ≠ 总参?激活率通常在 3-10%?

2.11 速查表:从 config.json 到参数量

给一张“抄作业”级别的公式汇总表。以后算任何模型,打开这张表逐行代入即可。

模块公式适用条件
Embedding (in)$V \times d$总是
Embedding (out)$V \times d$tie_word_embeddings=false
MHA Attention$4 \times d^2$$H_{kv}=H_q$ 且 $H_q \times D_h = d$
GQA Attention$d \times (H_q \times D_h) + 2 \times d \times (H_{kv} \times D_h) + (H_q \times D_h) \times d$通用
MLA (Q 侧)$d \times d_q + d_q \times H \times D_{nope} + d \times H \times D_{rope}$kv_lora_rankq_lora_rank 存在时
MLA (KV 侧)$d \times (d_{kv} + D_{rope}) + d_{kv} \times H \times (D_{nope} + D_v)$同上
MLA (output)$H \times D_v \times d$同上
SwiGLU FFN$3 \times d \times d_{ff}$hidden_act=silu 且 gate/up/down 分离
ReLU$^2$ FFN$2 \times d \times d_{ff}$hidden_act=relu2
MoE Router$d \times E$总是
MoE 总/层$d \times E + E \times \text{Params}{expert} + \text{Params}{shared}$总是
Mamba-2 in_proj$d \times (2d_{inner} + 2n_{groups}N + H_{mamba})$model_type 含 mamba
Mamba-2 out_proj$d_{inner} \times d$同上
Dense FFN同 SwiGLU/ReLU$^2$,见 intermediate_size / dense_intermediate_sizemoe_layer_freq[i]=0 的层
RMSNorm$d$每层 2 个,可忽略
激活参$\text{非MoE} + \text{Router} + k \times \text{Params}{expert} + \text{Params}{shared}$MoE 模型
总参上述所有求和

实战口诀

  1. 打开 config.json,圈出 $d, V, H_q, H_{kv}, D_h, d_{ff}, d_{moe_ff}, E, k$
  2. Embedding: $2Vd$(如果 tie_word_embeddings=false
  3. Attention/层: 查 GQA/MHA/MLA 公式
  4. FFN/层: 查 hidden_act 决定 ×2 还是 ×3
  5. MoE: $E \times$ FFN/层 + router
  6. 乘层数,加 MTP,加 Vision
  7. 总参 = 以上求和;激活参 = 非 MoE + $k \times$ 单专家
  8. 显存 = 激活参 × bytes/param(见 1.5 节)

术语中英对照

中文英文config 字段
隐藏维度hidden size / model dimensionhidden_size
注意力头数number of attention headsnum_attention_heads
KV 头数number of key-value headsnum_key_value_heads
每头维度head dimensionhead_dim
中间维度intermediate sizeintermediate_size
词表大小vocabulary sizevocab_size
路由专家routed expertsn_routed_experts
共享专家shared expertsn_shared_experts
每 token 专家数experts per tokennum_experts_per_tok
激活参数active / activated parameters
总参数total parameters
权重绑定weight tyingtie_word_embeddings
分组查询注意力Grouped Query Attention (GQA)$H_{kv} < H_q$
多头潜注意力Multi-head Latent Attention (MLA)kv_lora_rank 存在
状态空间模型State Space Model (SSM)ssm_state_size 存在
多 token 预测Multi-Token Prediction (MTP)num_nextn_predict_layers

CH1-2 常见计算错误

#常见错误正确做法
1用 $d$ 代替 $H_q \times D_h$ 算 K 投影GQA 中 K 的维度是 $H_{kv} \times D_h$,不是 $d$
2忘记 MLA 的 rope 投影不能压缩rope 部分用 $d \times H \times D_{rope}$,不经过潜空间
3混淆 intermediate_sizemoe_intermediate_sizeDense 层和 MoE 专家层可能用不同维度
4忘记乘 bytes参数量是“个数”,显存是“字节数”,中间差 2×(BF16)
5Mamba 的 $d_{inner}$ 没验证自洽$d_{inner} = \text{expand} \times d = H_{mamba} \times D_{mamba}$
6漏掉了 LM Headtie_word_embeddings=false 时 LM Head 是独立矩阵
7Router 参数当成 0虽然小(几 M),但要把所有层加起来
8激活参计算时忘记共享专家共享专家对每个 token 都激活,不算在 top-k 里

下一章预告:CH 3 FLOPs 估算——从参数量到计算量,推导 prefill/decode 的单 token FLOPs 公式,并给出 Nemotron/M3/K2.5 的完整 FLOPs 分解表。


CH 3 FLOPs 估算

读者定位:已完成 CH1-2 的参数计算,目标是推导 prefilling / decoding 的单 token 计算量,并理解不同架构(Full Attn / MSA / MLA / Mamba-2)的 FLOPs 差异根源。


3.1 通用原理

建立“前向 FLOPs = 所有权重矩阵乘法之和”的底层逻辑。参数量是“模型存了多少数”,FLOPs 是“每次前向要算多少下”——两者直接决定推理延迟和硬件成本。

核心公式

单层 FLOPs = 该层内所有矩阵乘法的 $2 \times m \times n \times k$ 之和(见 1.2 节)。

$$\text{FLOPs}_{\text{total}} = \sum_{l=1}^{L} \text{FLOPs}_{\text{attn}}^{(l)} + \text{FLOPs}_{\text{ffn}}^{(l)} + \text{FLOPs}_{\text{norm}}^{(l)}$$

其中 norm(RMSNorm / LayerNorm)的 FLOPs 为 $4 \times d$(乘 $\gamma$ + 加 $\beta$),在大模型中可忽略($d=8192$ 时 $\approx 32\text{K FLOPs}$,而 Q 投影是 $\approx 134\text{M FLOPs}$)。

Prefill vs Decode

  • Prefill:输入 $T_{in}$ 个 token,所有层对所有 token 完整计算一次。总 FLOPs 正比于 $T_{in}$(线性部分)或 $T_{in}^2$(注意力部分)。
  • Decode:每次只产生 1 个新 token,但需要 attend 到所有历史 token($T_{total}$)。只有新 token 的 QKV 需要投影,但 QK 点积和 V 加权要覆盖全部历史。
$$\text{FLOPs}_{\text{decode\_per\_token}} = \sum_{l=1}^{L} \text{FLOPs}_{\text{new\_token}}^{(l)}$$

Prefill 是“一口气读完整本书再回答问题”,Decode 是“每次多读一个字就要把所有笔记翻一遍”。前者吞吐高但延迟长,后者每步轻量但被历史长度拖累。Attention 的 O(T²) 项只在 Prefill 是全量爆炸,Decode 时变成 O(T)(因为只有 1 个 query)。

单 Token FLOPs 计算范式

对每个矩阵乘法,固定范式为:

$$\text{FLOPs} = 2 \times (\text{输出第一维}) \times (\text{输出第二维}) \times (\text{被缩并的公共维度})$$

案例:Attention 层 Q 投影,输入 hidden $[1, d]$,权重 $W_Q [d, H_q \times D_h]$:

$$\text{FLOPs}_Q = 2 \times 1 \times (H_q \times D_h) \times d$$

Nemotron 12 个 Attention 层之一($d=8192$,$H_q=64$,$D_h=128$):

$$\text{FLOPs}_Q = 2 \times 1 \times (64 \times 128) \times 8192 = 2 \times 8192 \times 8192 = 134{,}217{,}728 \approx 134.2\text{M FLOPs/token}$$

每产生一个 token,Q 投影就要把 8192 维向量乘上 $8192 \times 8192$ 的矩阵——相当于做 8192 次 8192 维的内积。这就是一个 token 经过一层 Attention 的“起步价”。


3.2 Full Attention FLOPs

逐项拆解标准 Attention(含 GQA)的四部分 FLOPs,区分线性项和平方项。不理解 O(T²) 项从哪里来,就无法理解为什么长上下文推理会变慢——以及为什么 MSA、Mamba 等替代架构有意义。

3.2.1 QKV 投影(线性项,O(T))

投影部分在 Prefill 时随 T 线性增长,在 Decode 时是常数(只投影新 token)。

$$\text{FLOPs}_{Q} = 2 \times d \times (H_q \times D_h) \times T_{\text{new}}$$$$\text{FLOPs}_{K} = 2 \times d \times (H_{kv} \times D_h) \times T_{\text{new}}$$$$\text{FLOPs}_{V} = 2 \times d \times (H_{kv} \times D_h) \times T_{\text{new}}$$

GQA 的精髓:K 和 V 投影的输出维度是 $H_{kv} \times D_h$ 而非 $H_q \times D_h$——这是 GQA 相比于 MHA 在计算量(而不仅是参数量)上的直接节省。

案例代入:Nemotron Attention 层(GQA 32:1,$d=8192$,$H_q=64$,$H_{kv}=2$,$D_h=128$)。

Prefill($T=4096$)

$$\text{FLOPs}_{Q} = 2 \times 8192 \times (64 \times 128) \times 4096 = 2 \times 8192 \times 8192 \times 4096$$$$= 2 \times 67{,}108{,}864 \times 4096 = 549{,}755{,}813{,}888 \approx 550 \text{ GFLOPs}$$$$\text{FLOPs}_{K} = 2 \times 8192 \times (2 \times 128) \times 4096 = 2 \times 8192 \times 256 \times 4096$$$$= 2 \times 2{,}097{,}152 \times 4096 = 17{,}179{,}869{,}184 \approx 17.2 \text{ GFLOPs}$$$$\text{FLOPs}_{V} = 17.2 \text{ GFLOPs} \quad (\text{与 K 相同})$$

Prefill 一次性投影所有 4096 个 token 的 Q、K、V。注意 K 投影(17 GFLOPs)只占 Q 投影(550 GFLOPs)的约 3%——因为 $H_{kv} = 2$ 只有 $H_q = 64$ 的 1/32。

Decode($T_{\text{new}}=1$,$T_{\text{total}}=1\text{M}$)

$$\text{FLOPs}_{Q} = 2 \times 8192 \times (64 \times 128) \times 1 = 134{,}217{,}728 \approx 134.2\text{M FLOPs}$$$$\text{FLOPs}_{K} = 2 \times 8192 \times (2 \times 128) \times 1 = 4{,}194{,}304 \approx 4.2\text{M FLOPs}$$$$\text{FLOPs}_{V} = 4.2\text{M FLOPs}$$

QKV 投影在 decode 时总共 $\approx 142.6\text{M FLOPs}$——与上下文长度无关

QKV 投影就像“打字”——每个新 token 只需要把自己的向量投影一次。历史 token 的 K 和 V 投影结果被缓存在 KV cache 里,不用重算。

3.2.2 QK 点积(平方项,O(T²) 的根源)

$$\text{FLOPs}_{\text{QK}} = 2 \times H_q \times T_{\text{new}} \times T_{\text{total}} \times D_h$$

Prefill($T=T_{\text{new}}=T_{\text{total}}=4096$,causal mask 下约计算一半)

$$\text{FLOPs}_{\text{QK}} = 2 \times 64 \times 4096 \times \frac{4096}{2} \times 128 = 2 \times 64 \times 4096 \times 2048 \times 128$$$$= 2 \times 64 \times 8{,}388{,}608 \times 128 = 137{,}438{,}953{,}472 \approx 137 \text{ GFLOPs}$$

(精确无 causal 时为 275 GFLOPs,causal mask 下约折半。)

Decode($T_{\text{new}}=1$,$T_{\text{total}}=1\text{M}$)——这就是长上下文问题的核心

$$\text{FLOPs}_{\text{QK}} = 2 \times 64 \times 1 \times 1{,}000{,}000 \times 128$$$$= 2 \times 64 \times 128 \times 10^6 = 16{,}384 \times 10^6 = 1.6384 \times 10^{10} \approx 16.4 \text{ GFLOPs}$$

当上下文达到 1M tokens 时,仅一个 Attention 层的 QK 点积就需要 164 亿次浮点运算。对于有 12 个 Attention 层的 Nemotron:$12 \times 16.4 \approx 197 \text{ GFLOPs}$,仅此一项就超过了 QKV 投影(12 × 142.6M ≈ 1.7 GFLOPs)两个数量级。

QK 点积是把新 token 的一个 query 与缓存中所有 1M 个 key 逐一算相似度。1M 个 key,每个 128 维,每个维度一次乘法+一次加法=$2 \times 128 = 256$ FLOPs,64 个 head 各做一次,总计就是 $64 \times 1\text{M} \times 256 = 16.4\text{GFLOPs}$。这就是 Attention 在长上下文下“喘不过气”的根本原因。

3.2.3 V 加权(同样是 O(T) 项,decode 中体量等于 QK)

$$\text{FLOPs}_{\text{V}} = 2 \times H_q \times T_{\text{new}} \times T_{\text{total}} \times D_h$$

Decode($T_{\text{new}}=1$,$T_{\text{total}}=1\text{M}$)

$$\text{FLOPs}_{\text{V}} = 2 \times 64 \times 1 \times 1{,}000{,}000 \times 128 = 16.4 \text{ GFLOPs}$$

与 QK 点积等量级!原因:注意力权重要乘上 V 矩阵——1M 个 value 向量,每个 128 维,64 个 head。计算量路径:$[1, 64, 1, 1\text{M}] \times [1, 64, 1\text{M}, 128] \to [1, 64, 1, 128]$,缩并维度是 1M。

算完“每个历史 token 有多相关”(QK 点积)后,还要把 1M 个 value 向量按相关性加权平均。这个“加权平均”的运算量跟“计算相似度”一样大——都是 $2 \times H \times T \times D_h$。所以 Attention 的 decode 成本 = QK + V ≈ $4 \times H \times T \times D_h$。

3.2.4 输出投影(线性项,O(T))

$$\text{FLOPs}_{\text{O}} = 2 \times d \times (H_q \times D_h) \times T_{\text{new}}$$

decode 时为常数(Nemotron):$\text{FLOPs}_O = 2 \times 8192 \times 8192 \times 1 = 134.2\text{M FLOPs}$

与 Q 投影相同——因为输入和输出的维度都是 $d \times d$。

3.2.5 单层 Full Attention Decode FLOPs 汇总

以 Nemotron Attention 层(GQA 32:1,T=1M)为例:

组件公式$T=1\text{M}$ 时 FLOPs占比
Q 投影$2 \times d \times (H_q \times D_h)$134.2M0.4%
K 投影$2 \times d \times (H_{kv} \times D_h)$4.2M0.01%
V 投影$2 \times d \times (H_{kv} \times D_h)$4.2M0.01%
QK 点积$2 \times H_q \times T \times D_h$16.4G49.7%
V 加权$2 \times H_q \times T \times D_h$16.4G49.7%
O 投影$2 \times d \times (H_q \times D_h)$134.2M0.4%
单层合计~33.1G100%

关键观察:在 1M 上下文下,Attention 层 99.4% 的计算量花在 QK 点积和 V 加权上——这两个 O(T) 项(decode 时)。投影部分是常数,可以忽略。任何想加速长上下文推理的架构,都是从这两个 O(T) 项下手。

3.2.6 GQA 对 FLOPs 的影响

GQA 降低了 K 和 V 投影的 FLOPs($H_{kv}$ 替代 $H_q$),但不降低 QK 点积和 V 加权的 FLOPs。原因是 K 和 V 在注意力计算前会被 repeat_kv 扩展到与 Q 相同的头数:

1
2
# 标准 GQA 实现(transformers 源码)
K = K.repeat_interleave(H_q // H_kv, dim=1)  # [B, H_kv, T, D] -> [B, H_q, T, D]

所以 QK 点积的规模仍然是 $2 \times H_q \times T \times D_h$——与 MHA 完全相同

GQA 节省的是:

  • K、V 投影的 FLOPs(节省比例 $\frac{H_q}{H_{kv}}$ 倍,如 64/2=32×)
  • KV cache(同样 32×)

GQA 节省的 不是

  • QK 点积的 FLOPs
  • V 加权的 FLOPs

GQA 就像“出版社印了 64 份杂志(Q head),但只审了 2 份稿子(KV head),审稿费省了 32×,但印杂志的成本(读者阅读 = QK 点积)没省——因为每份杂志都要发给所有读者看。”


3.3 MSA 稀疏 Attention FLOPs(MiniMax M3)

推导 M3 的 Multi-stage Sparse Attention 计算量,理解“用廉价 Index Branch 筛选 + 昂贵 Main Branch 只在筛选区域计算”的 FLOPs 逻辑。M3 在 1M 上下文时实现约 30× 的 decode 加速——这是稀疏 Attention 的标杆案例。

3.3.1 MSA 架构概述

M3 的 MSA 将 Attention 分为两个分支:

  • Index Branch(廉价筛选器):用少量 head($H_{\text{idx}} = 4$)在全部 T 个 token 上做 QK 评分 + max-pool + top-k,选出 16 个 block(每 block 128 token,共 $16 \times 128 = 2048$ 个候选 token)。
  • Main Branch(精准计算器):用全部 head($H_q = 64$)只在 2048 个入选 token 上做完整 Attention。

M3 有 60 层:3 层 Full Attention(Layer 0-2)+ 57 层 MSA(Layer 3-59)。

3.3.2 Index Branch FLOPs

维度回顾:$d = 6144$,$H_{\text{idx}} = 4$,$D_{\text{idx}} = 128$,$H_q = 64$,$D_h = 128$。

(1) Index Q 投影

$$\text{FLOPs}_{\text{idx\_Q}} = 2 \times d \times (H_{\text{idx}} \times D_{\text{idx}}) \times T_{\text{new}}$$

Decode($T_{\text{new}}=1$):

$$\text{FLOPs}_{\text{idx\_Q}} = 2 \times 6144 \times (4 \times 128) \times 1 = 2 \times 6144 \times 512 = 6{,}291{,}456 \approx 6.3\text{M FLOPs}$$

(2) Index K 投影

Index K 只有一个 head 的维度(128),4 个 index head 共享同一个 K:

$$\text{FLOPs}_{\text{idx\_K}} = 2 \times d \times D_{\text{idx}} \times T_{\text{new}} = 2 \times 6144 \times 128 \times 1 = 1{,}572{,}864 \approx 1.6\text{M FLOPs}$$

(3) Index QK 评分(O(T²) in prefill,O(T) in decode)

这是 Index Branch 的计算主体。Index Branch 用 4 个 head 在全序列上做 QK 点积。

Decode($T_{\text{new}}=1$,$T_{\text{total}}=1\text{M}$)

$$\text{FLOPs}_{\text{idx\_QK}} = 2 \times H_{\text{idx}} \times 1 \times T \times D_{\text{idx}} = 2 \times 4 \times 1 \times 10^6 \times 128$$$$= 2 \times 512 \times 10^6 = 1{,}024{,}000{,}000 \approx 1.02\text{ GFLOPs}$$

对比 Full Attention 的 QK 点积(如果用全部 64 个 head 做全序列评分):

$$\text{FLOPs}_{\text{full\_QK}} = 2 \times 64 \times 1 \times 10^6 \times 128 = 16{,}384 \times 10^6 \approx 16.4\text{ GFLOPs}$$

Index Branch 的 QK 评分仅需要 1.02 GFLOPs,而 Full Attention 需要 16.4 GFLOPs——减少了 16×。原因直截了当:4 个 head vs 64 个 head,$64/4 = 16$。

这就是 Index Branch 设计的精妙之处:用 16× 更便宜的计算,筛选出哪些 token 值得做完整的 64-head Attention。

(4) Max-pool + Top-k

Max-pool 将分数按 block 聚合(每 128 token 一个 block,共 $T/128$ 个 block),再选出 top-16 个 block。这部分本质是遍历和排序,FLOPs $\approx T/128 \times \log(16)$,约 $10^4$ 级别,完全可忽略。

3.3.3 Main Branch FLOPs

Main Branch 的核心:只在入选的 2048 个 token 上做完整 Attention。

$$\text{访问 token 数} = \text{block\_size} \times \text{top\_k\_blocks} = 128 \times 16 = 2048$$

(1) Main QK 点积

$$\text{FLOPs}_{\text{main\_QK}} = 2 \times H_q \times T_{\text{new}} \times T_{\text{selected}} \times D_h$$$$\text{Decode} = 2 \times 64 \times 1 \times 2048 \times 128 = 2 \times 64 \times 262{,}144$$$$= 33{,}554{,}432 \approx 33.6\text{M FLOPs}$$

关键对比:Full Attention 的 QK = $16.4\text{G FLOPs}$,MSA Main QK = $33.6\text{M FLOPs}$。加速比 = $16.4\text{G} / 33.6\text{M} \approx 488\times$(T=1M 时,仅 QK 部分)。

(2) Main V 加权

$$\text{FLOPs}_{\text{main\_V}} = 2 \times H_q \times T_{\text{new}} \times T_{\text{selected}} \times D_h = 33.6\text{M FLOPs}$$

与 Main QK 对称。

3.3.4 MSA 单层 Decode FLOPs 汇总(T=1M)

组件FLOPs类别
Index Q 投影6.3M常数
Index K 投影1.6M常数
Index QK 评分1.02GO(T),但 16× 小
Index max-pool + top-k~0可忽略
Main Q 投影$2 \times 6144 \times (64 \times 128) = 100.7\text{M}$常数
Main K 投影$2 \times 6144 \times (4 \times 128) = 6.3\text{M}$常数(GQA 16:1)
Main V 投影6.3M常数
Main QK 点积33.6M常数(仅 2048 个 token)
Main V 加权33.6M常数
Main O 投影$2 \times 6144 \times (64 \times 128) = 100.7\text{M}$常数
总计~1.31G

对比 Full Attention 层的 $\approx 33.1\text{G FLOPs}$(相同 $d$, $H_q$ 配置在 T=1M 下),MSA 单层仅需 $\approx 1.31\text{G FLOPs}$——加速约 25×

MSA 单层最大的开销是 Index QK 评分(1.02G,占 78%),这一项仍然随 T 线性增长——但它是用 4 个 head 而非 64 个,系数差距是 16×。

3.3.5 总体加速比

Decode 场景(T=1M)

对于 M3 的 57 层 MSA + 3 层 Full Attention:

  • 3 层 Full Attention:$3 \times 33.1\text{G} \approx 99.3\text{G FLOPs}$($d=6144$, $H_q=64$, $H_{kv}=4$)
  • 57 层 MSA:$57 \times 1.31\text{G} \approx 74.7\text{G FLOPs}$
  • 总计:$\approx 174\text{G FLOPs}$ 用于 Attention 部分

假如同样的 60 层全部是 Full Attention:

  • $60 \times 33.1\text{G} \approx 1986\text{G FLOPs} \approx 1.99\text{T FLOPs}$
  • 加速比 $\approx 1986 / 174 \approx 11.4\times$(仅 Attention 部分)

Prefill 场景(T=1M,causal),加速更显著:

  • Index QK 的 O(T²) 部分:$2 \times 4 \times (10^6)^2/2 \times 128 \approx 5.12 \times 10^{14}$ FLOPs/层
  • Full Attention QK 的 O(T²) 部分:$2 \times 64 \times (10^6)^2/2 \times 128 \approx 8.19 \times 10^{15}$ FLOPs/层
  • Main Branch QK:$2 \times 64 \times 10^6 \times 2048 \times 128 \approx 3.36 \times 10^{13}$ FLOPs/层(常数,不随 T² 增长)
  • 加速比 $\approx 8.19 \times 10^{15} / (5.12 \times 10^{14} + 3.36 \times 10^{13}) \approx 15\times$(仅 QK 部分)

综合其他恒定开销,实际整体 decode 加速约 2-5×,Prefill 加速约 10-20×(取决于序列长度和 overhead 比例)。论文声称的 30× 是 decode 场景下 Attention 部分 QK+V 的加速。

MSA 的哲学是“先粗筛再精算”。花 1 GFLOPs(Index Branch)扫一眼全场,发现最有戏的 2048 个 token,然后花 67 MFLOPs(Main QK+V)在这 2048 个 token 上精算。而 Full Attention 要花 33 GFLOPs 在所有 1M 个 token 上精算。前者总花费 $\approx 1.1\text{G}$,后者 $\approx 33\text{G}$,高下立判。


3.4 MLA FLOPs(Kimi K2.5 / DeepSeek V4)

推导 Multi-head Latent Attention 的 FLOPs,区分低秩投影的线性节省和 QK 点积的不变性。MLA 的卖点是“省 KV cache”而非“省 FLOPs”——但低秩投影确实也节省了一部分线性 FLOPs。

3.4.1 MLA 计算流程回顾

以 Kimi K2.5 为例($d=7168$,$d_{kv}=512$,$d_q=1536$,$H=64$,$D_{\text{nope}}=128$,$D_{\text{rope}}=64$,$D_v=128$):

MLA 的两阶段计算

  1. 压缩阶段:hidden $\to$ latent($W_{kv_a}$, $W_{q_a}$)
  2. 解压阶段:latent $\to$ per-head K, V, Q($W_{kv_b}$, $W_{q_b}$)
  3. RoPE 直接投影:hidden $\to$ per-head Q/K rope($W_{q_rope}$,不经过 latent)

3.4.2 KV 侧 FLOPs(线性项节省的来源)

(1) KV 压缩投影 $W_{kv_a}$

$$W_{kv\_a}: [d] \to [d_{kv} + D_{\text{rope}}] = 7168 \to 512 + 64 = 576$$$$\text{FLOPs}_{kv\_a} = 2 \times d \times (d_{kv} + D_{\text{rope}}) \times T_{\text{new}}$$

Decode:$= 2 \times 7168 \times 576 \times 1 = 8{,}257{,}536 \approx 8.3\text{M FLOPs}$

这个投影产生两部分输出:

  • 前 512 维:压缩的 KV latent,进入 $W_{kv_b}$ 解压
  • 后 64 维:K 的 RoPE 分量(不压缩),直接用于注意力计算

(2) KV 解压投影 $W_{kv_b}$

$$W_{kv\_b}: [d_{kv}] \to [H \times (D_{\text{nope}} + D_v)] = 512 \to 64 \times (128 + 128) = 64 \times 256 = 16384$$$$\text{FLOPs}_{kv\_b} = 2 \times d_{kv} \times H \times (D_{\text{nope}} + D_v) \times T_{\text{new}}$$

Decode:$= 2 \times 512 \times 64 \times 256 \times 1 = 16{,}777{,}216 \approx 16.8\text{M FLOPs}$

这个投影从 512 维 latent 中“解压”出 64 个 head,每个 head 有 128 维 nope K 和 128 维 V。等效于用一个 $512 \times 16384$ 的矩阵做投影——但比直接从 $7168 \to 16384$(MHA 方式)的 $7168 \times 16384 = 117.4\text{M}$ 矩阵 小了 7×

3.4.3 Q 侧 FLOPs

(3) Q 压缩投影 $W_{q_a}$

$$W_{q\_a}: [d] \to [d_q] = 7168 \to 1536$$$$\text{FLOPs}_{q\_a} = 2 \times d \times d_q \times T_{\text{new}}$$

Decode:$= 2 \times 7168 \times 1536 \times 1 = 22{,}020{,}096 \approx 22.0\text{M FLOPs}$

(4) Q nope 解压投影 $W_{q_b}$

$$W_{q\_b}: [d_q] \to [H \times D_{\text{nope}}] = 1536 \to 64 \times 128 = 8192$$$$\text{FLOPs}_{q\_b} = 2 \times d_q \times H \times D_{\text{nope}} \times T_{\text{new}}$$

Decode:$= 2 \times 1536 \times 64 \times 128 \times 1 = 25{,}165{,}824 \approx 25.2\text{M FLOPs}$

(5) Q RoPE 直投投影 $W_{q_rope}$

RoPE 分量必须直接从 hidden 维度投影,不能经过压缩——因为 RoPE 的旋转操作施加在维度对上,压缩会破坏这个结构。

$$W_{q\_rope}: [d] \to [H \times D_{\text{rope}}] = 7168 \to 64 \times 64 = 4096$$$$\text{FLOPs}_{q\_rope} = 2 \times d \times H \times D_{\text{rope}} \times T_{\text{new}}$$

Decode:$= 2 \times 7168 \times 64 \times 64 \times 1 = 58{,}720{,}256 \approx 58.7\text{M FLOPs}$

注意:$W_{q_rope}$ 是 MLA 中第二大的单项 FLOPs(仅次于输出投影),因为 RoPE 部分不能享受低秩压缩的红利。

3.4.4 QK 点积与 V 加权(O(T²) 项——与 MHA 完全等同)

MLA 的 QK 点积分为两部分:

(6a) nope 分量的 QK 点积

$$\text{FLOPs}_{QK_{\text{nope}}} = 2 \times H \times T_{\text{new}} \times T_{\text{total}} \times D_{\text{nope}}$$

Decode(T=1M):$= 2 \times 64 \times 1 \times 10^6 \times 128 = 16.4\text{G FLOPs}$

(6b) rope 分量的 QK 点积

$$\text{FLOPs}_{QK_{\text{rope}}} = 2 \times H \times T_{\text{new}} \times T_{\text{total}} \times D_{\text{rope}}$$

Decode(T=1M):$= 2 \times 64 \times 1 \times 10^6 \times 64 = 8.2\text{G FLOPs}$

(6c) 合计 QK 点积

$$\text{FLOPs}_{QK} = 2 \times H \times T \times (D_{\text{nope}} + D_{\text{rope}}) = 2 \times H \times T \times D_h$$$$= 2 \times 64 \times 10^6 \times 192 = 24.6\text{G FLOPs}$$

其中 $D_h = 128 + 64 = 192$。这与标准 MHA($D_h=192$)的 QK 点积 FLOPs 完全相等。

(7) V 加权

$$\text{FLOPs}_{V} = 2 \times H \times T_{\text{new}} \times T_{\text{total}} \times D_v$$

Decode(T=1M):$= 2 \times 64 \times 1 \times 10^6 \times 128 = 16.4\text{G FLOPs}$

3.4.5 输出投影

(8) 输出投影 $W_o$

$$W_o: [H \times D_v] \to [d] = (64 \times 128) = 8192 \to 7168$$$$\text{FLOPs}_o = 2 \times H \times D_v \times d \times T_{\text{new}}$$

Decode:$= 2 \times 64 \times 128 \times 7168 \times 1 = 117{,}440{,}512 \approx 117.4\text{M FLOPs}$

3.4.6 MLA 单层 Decode FLOPs 汇总(T=1M)

组件FLOPs类型vs MHA 同配置
$W_{kv_a}$(KV 压缩)8.3M常数—(MLA 新增)
$W_{kv_b}$(KV 解压)16.8M常数—(MLA 新增)
$W_{q_a}$(Q 压缩)22.0M常数—(MLA 新增)
$W_{q_b}$(Q nope 解压)25.2M常数—(MLA 新增)
$W_{q_rope}$(Q RoPE 直投)58.7M常数MHA Q proj 176.2M → 节省 3×
QK 点积(nope + rope)24.6GO(T)相同
V 加权16.4GO(T)相同
$W_o$(输出投影)117.4M常数相同
单层合计~41.2GMHA 同配置 ~57.4G

MLA 单层节省的 FLOPs 主要来自于:用多个小矩阵(低秩)替代 Q、K、V 的直投大矩阵。$W_{kv_a}$ + $W_{kv_b}$ + $W_{q_a}$ + $W_{q_b}$ + $W_{q_rope}$ 合计 $\approx 131\text{M FLOPs}$,而标准 MHA 的 Q+K+V 三个直投矩阵合计 $\approx 2 \times 7168 \times 64 \times 192 \times 3 \approx 528\text{M FLOPs}$。线性项节省约 4×

但 QK 点积(24.6G)+ V 加权(16.4G)= 41G——这部分在 T=1M 时占比超过 99%,且与标准 MHA 完全相同

3.4.7 关键结论

MLA 省的是 KV cache,不是 FLOPs 的主体。

  • 线性项(投影):MLA 将 QKV 投影从 $\approx 528\text{M}$ 降到 $\approx 131\text{M FLOPs/token}$,但这项在长上下文下只占总 FLOPs 的 $\sim 0.3%$。
  • 平方项/长上下文项(QK + V):MLA 的 FLOPs 与 MHA 完全相同——$2 \times H \times T \times D_h$——因为最终 attention 计算的维度规模没有变。
  • KV Cache:MLA 将每个 token 的 KV cache 从 $2 \times H \times D_h = 2 \times 64 \times 192 = 24{,}576$ 个元素压到 $d_{kv} + D_{\text{rope}} = 512 + 64 = 576$ 个元素——压缩 43×。这才是 MLA 的主要价值。

MLA 就像“快递打包”——包裹运输时压缩(KV cache 小),但到了收件人手里必须拆开原样呈现(注意力计算时的 K、V 维度与 MHA 完全相同)。运费省了(显存),但收件人验货的工作量没少(FLOPs)。


3.5 Mamba-2 SSD FLOPs(Nemotron)

逐项拆解 Mamba-2 Structured State Space Duality 层的 FLOPs,展示为什么它是 O(T) 而非 O(T²)。Mamba-2 是 Nemotron 的核心非 Attention 序列建模层——48 个 Mamba 层的 FLOPs 特征决定了整个模型的长上下文行为。

3.5.1 Mamba-2 计算流程回顾

维度回顾(Nemotron):$d=8192$,$\text{expand}=2 \Rightarrow d_{\text{inner}}=16384$,$H_{\text{mamba}}=256$,$D_{\text{mamba}}=64$,$N=128$(ssm_state_size),$n_{\text{groups}}=8$,$C=128$(chunk size)。

验证自洽性:$d_{\text{inner}} = H_{\text{mamba}} \times D_{\text{mamba}} = 256 \times 64 = 16384$。$\checkmark$

Mamba-2 的 SSD 将序列分成大小为 C 的 chunk,每个 chunk 内部做因果 matmul(对角块),chunk 之间通过状态传递(非对角块)。总计算量分为四部分:

3.5.2 (a) in_proj 输入投影(线性项主力)

in_proj 一次性产生所有需要的分量:$\mathbf{x}$、$\mathbf{z}$、$\mathbf{B}$、$\mathbf{C}$、$\boldsymbol{\Delta}$。

投影维度:$d \to 2 \times d_{\text{inner}} + 2 \times n_{\text{groups}} \times N + H_{\text{mamba}}$
$= 8192 \to 2 \times 16384 + 2 \times 8 \times 128 + 256$
$= 8192 \to 32768 + 2048 + 256 = 35072$

$$\text{FLOPs}_{\text{in\_proj}} = 2 \times d \times 35072 \times T_{\text{new}}$$

Decode:$= 2 \times 8192 \times 35072 \times 1 = 574{,}619{,}648 \approx 574.6\text{M FLOPs}$

这是 Mamba-2 层单 token 计算中最大的一项。对比 Attention 的 Q 投影(134M),Mamba 的 in_proj 约大 4.3×——因为它是一次性投影出 5 个分量(x, z, B, C, Δ),相当于把 Attention 的 Q、K、V、外加两个额外的分量合并到一个矩阵里。

3.5.3 (b) conv1d 深度卷积(可忽略)

一维深度卷积,核大小 = 4,输入通道数 = $d_{\text{conv}} = d_{\text{inner}} + 2 \times n_{\text{groups}} \times N = 16384 + 2048 = 18432$。

$$\text{FLOPs}_{\text{conv1d}} = 2 \times d_{\text{conv}} \times \text{kernel} \times T_{\text{new}}$$

Decode:$= 2 \times 18432 \times 4 \times 1 = 147{,}456 \approx 0.15\text{M FLOPs}$

卷积核只有 4 个元素宽,而且是深度卷积(每个通道独立的 1D 卷积),所以计算量跟 in_proj 比可以忽略不计——就像“顺丰快递的包装费相对于货品价值”。

3.5.4 (c) SSD 对角块(chunk 内因果 matmul)

这是 Mamba-2 “Attention 等价” 的部分。在每个 chunk 内,SSD 做类似因果 Attention 的计算:

$$\text{FLOPs}_{\text{diag}} = 2 \times \frac{T}{C} \times \frac{C^2}{2} \times H_{\text{mamba}} \times D_{\text{mamba}} = T \times C \times H_{\text{mamba}} \times D_{\text{mamba}}$$

代入:$= T \times 128 \times 256 \times 64 = T \times 2{,}097{,}152$

Prefill(T=4096):$4096 \times 2{,}097{,}152 \approx 8.59 \times 10^9 \approx 8.6\text{G FLOPs}$

Decode($T_{\text{new}}=1$,但 chunk 内的因果 matmul 在 decode 时仅涉及当前 chunk 的累积状态):约 4.2M FLOPs(与 T 无关)。

这里需要澄清:在 decode 阶段,Mamba-2 不需要对每个新 token 重做所有 chunk 的内部计算——SSD 的递归特性意味着新 token 只需要更新当前 chunk 的对角块和状态传递。因此 decode 时这部分是常数。

3.5.5 (d) SSD 非对角块:chunk 间的状态传递

前面的对角块是每个 chunk “内部消化”——chunk 里的每个 token 看到前面 token 的计算。但 chunk 1 的最后一个 token 怎么看到 chunk 0 的第一个 token?这需要状态传递

Mamba-2 的 SSM 在每个 chunk 边界维护一个隐藏状态 $h \in \mathbb{R}^{H_{\text{mamba}} \times N}$($N = d_{state} = 128$)。这个状态向量"记住"了之前所有 chunk 的摘要。

当一个 chunk 结束时,它的状态 $h_{i}$ 需要"传递"给下一个 chunk。传递的数学操作是:下一个 chunk 的每个位置,将传入状态与当前 chunk 的 $C$(输出投影)相乘,得到对当前 chunk 内每个 token 的修正量。这个操作为每个 chunk 边界做一次 $h_i \times C_{i+1}$。

$h_i$ 的形状是 $[H_{\text{mamba}}, N]$,$C_{i+1}$(经过 decay 加权后)的形状也是 $[H_{\text{mamba}}, N]$。这不是简单的向量点积——Mamba-2 需要在 $N$ 维空间内做"状态混合",让 $N$ 维的每个分量都能影响当前 chunk 的输出。因此,实际的状态传递矩阵是一个 $[N, N]$ 的变换:

$$\text{FLOPs}_{\text{off-diag}} = 2 \times \underbrace{\frac{T}{C}}_{\text{chunk 数}} \times \underbrace{H_{\text{mamba}}}_{\text{heads}} \times \underbrace{N^2}_{\text{状态传递矩阵}}$$

代入 Nemotron 的数值:chunk 数 $= T/128$,$H_{\text{mamba}} = 256$,$N = 128$:

$$= 2 \times \frac{T}{128} \times 256 \times 128^2 = 2 \times \frac{T}{128} \times 256 \times 16{,}384$$$$= 2 \times \frac{T}{128} \times 4{,}194{,}304 = T \times 65{,}536 \approx 6.55 \times 10^4 \times T$$

Prefill(T=4096):$4096 \times 65{,}536 \approx 0.27\text{G FLOPs}$

Decode:约 $6.55 \times 10^4$ FLOPs(常数级别)。

对角块和非对角块加起来,就是 SSD 的完整 FLOPs。对角块做"块内注意"($O(C^2)$),非对角块做"块间传递"($O(N^2)$)。$C = 128$、$N = 128$ 时,$C^2 = N^2$——这是设计上的巧合,不是必然。如果 chunk_size 变了,对角块和非对角块的比例就会偏移。

3.5.6 (e) out_proj 输出投影

$$\text{FLOPs}_{\text{out\_proj}} = 2 \times d_{\text{inner}} \times d \times T_{\text{new}}$$

Decode:$= 2 \times 16384 \times 8192 \times 1 = 268{,}435{,}456 \approx 268.4\text{M FLOPs}$

3.5.7 Mamba-2 单层 FLOPs 汇总

Prefill(T=4096)

组件FLOPs占比复杂度
in_proj$574.6\text{M} \times 4096 = 2.35\text{T}$92.3%O(T)
conv1d$0.15\text{M} \times 4096 = 0.61\text{G}$~0%O(T)
SSD 对角块8.6G0.3%O(T×C)
SSD 非对角块0.27G~0%O(T)
out_proj$268.4\text{M} \times 4096 = 1.10\text{T}$7.4%O(T)
单层合计~3.46T FLOPs100%O(T)

48 层合计:$\approx 166\text{T FLOPs}$(prefill 4096 token)。全部是 O(T)——没有任何 O(T²) 项。

Decode($T_{\text{new}}=1$,$T=1\text{M}$)

组件FLOPs复杂度
in_proj574.6MO(1)
conv1d0.15MO(1)
SSD 对角块 (decode)~4.2MO(1)
SSD 非对角块 (decode)~0.07MO(1)
out_proj268.4MO(1)
单层合计~847MO(1)

48 层 Mamba-2 合计:$\approx 40.7\text{G FLOPs/token}$(与 T 无关!)

这是最关键的数字:Mamba-2 层的 decode FLOPs 与上下文长度完全无关——每 token 固定 $\approx 847\text{M FLOPs}$。而 Attention 层在 T=1M 时需要 $\approx 33.1\text{G FLOPs/token}$。

3.5.8 与 Attention 的对比:O(T) vs O(T²)

以 1M 上下文为例,单层对比

指标Full Attention (GQA)Mamba-2 SSD比率
线性项 (proj)277M843M0.33×(Mamba 更贵)
长上下文项 (QK/sSD)32.8G~4.3M7600×(Attention 更贵)
单层总计33.1G847M39×(Mamba 更快)

48 层 Mamba-2($\approx 40.7\text{G FLOPs}$) vs 48 层 Full Attention($\approx 48 \times 33.1\text{G} \approx 1.59\text{T FLOPs}$)——Mamba 快 39×

Mamba-2 的 SSD 是“聪明地算”——把 O(T²) 的 Attention 变成了 chunk 内 O(C²) 的因果 matmul(C=128,常数)。1M 个 token 被切成 ~7812 个 chunk,每个 chunk 内部做的计算量恒定。新 token 到来时,只更新当前 chunk 并传播状态。而 Attention 每来一个新 token,都要跟全部 1M 个历史 token 逐一“打招呼”。这就是 O(T) vs O(T²) 的本质区别。


3.6 Sliding Window Attention(SWA)FLOPs

Sliding Window Attention 是 MiMo-V2-Flash、Mistral 等模型使用的稀疏 Attention 方案。每个 token 只关注它前面固定窗口 $W$ 内的 token,而非全部 $T$ 个 token。

QK 点积的复杂度从 $O(T^2)$ 降到 $O(T \times W)$:

$$\text{FLOPs}_{\text{QK, SWA}} = 2 \times H_q \times T_{\text{new}} \times \min(T, W) \times D_h$$
  • Prefill(每个 token 看到前面 $W$ 个):$2 \times H_q \times T \times W \times D_h$
  • Decode(新 token 只往前看 $W$ 步):$2 \times H_q \times 1 \times W \times D_h$

以 MiMo-V2-Flash 为例:$H_q = 64$,$W = 131072$,$D_h = 128$。Prefill 时 $T=W=131\text{K}$:$2 \times 64 \times 131072 \times 131072 \times 128 \approx 2.8 \times 10^{14}$ FLOPs,是 Full Attention($8.4 \times 10^{14}$)的约 $1/3$。但 decode 时:$2 \times 64 \times 1 \times 131072 \times 128 = 2.15 \times 10^9$ FLOPs——与 Full Attention decode 完全相同(因为 decode 时 $T_{new}=1$,Full Attn 也只看全部 $T$ 个历史 token)。

SWA 省的是 prefill 而非 decode。它适合吞吐优先的短上下文场景,但在长上下文 decode 上没有优势。

SWA 的 $W$ 不是凭空取的——通常等于 max_position_embeddingssliding_window 字段。如果 config 中找不到 sliding_window 但模型声称是 SWA,查看 max_position_embeddings 是否与上下文窗口匹配。

3.7 Gated DeltaNet(Linear Attention)FLOPs

Gated DeltaNet 是 Qwen3.5-MoE 等模型使用的线性注意力方案。与 Mamba-2 共享核心思想——用固定大小的隐藏状态 $S_t \in \mathbb{R}^{H \times D_h \times D_h}$ 取代 Attention 的 $O(T^2)$ 点积。

DeltaNet 的更新规则(简化):

$$S_t = \alpha_t \cdot S_{t-1} + \beta_t \cdot (k_t \otimes v_t)$$

其中 $k_t \otimes v_t$ 是 key 和 value 的外积,形状为 $[H, D_h, D_h]$。$\alpha_t$ 是遗忘门(decay),$\beta_t$ 是输入门(input gate),两者都是通过投影从当前输入得到的标量。

输出:$y_t = S_t \cdot q_t$,其中 $S_t \cdot q_t$ 将一个 $[H, D_h, D_h]$ 矩阵与 $[H, D_h]$ 向量相乘,得到 $[H, D_h]$ 的注意力输出。

每 token FLOPs 分解

$$\text{FLOPs}_{\text{DeltaNet}} = \underbrace{2 \times H \times D_h^2}_{\text{外积 } k_t \otimes v_t} + \underbrace{2 \times H \times D_h^2}_{\text{状态乘 } S_t \cdot q_t} + \underbrace{2 \times H \times D_h^2}_{\text{状态更新 } S_t = \alpha S_{t-1} + \beta(k \otimes v)}$$

三项各 $2 \times H \times D_h^2$,合计 $6 \times H \times D_h^2$。全与 $T$ 无关——DeltaNet 的 decode FLOPs 是常数

以 Qwen3.5-MoE 为例($H = 64$,$D_h = 128$):$6 \times 64 \times 128^2 = 6 \times 64 \times 16384 \approx 6.3 \times 10^6$ FLOPs/token/layer。对比 Full Attention 的 decode($2 \times 64 \times 10^6 \times 128 \approx 1.6 \times 10^{10}$),DeltaNet 节省了约 2500×

与 Mamba-2 的核心差异:Mamba-2 通过 in_proj 一次性产生所有 SSM 参数($\Delta, B, C$),其输入投影的 FLOPs 远大于 SSM 核心计算。DeltaNet 的投影更简单(类似标准 Attention 的 QKV 投影),所以整体 FLOPs 更小。但 Mamba-2 的状态维度 $H \times N$($256 \times 128$)比 DeltaNet 的 $H \times D_h^2$($64 \times 128^2$)小得多——状态大小是 $O(H \times N)$ vs $O(H \times D_h^2)$,差了 $D_h$ 倍。

3.8 MoE Gating FLOPs

计算路由器(Router / Gate)的 FLOPs,证明它在总计算量中占比 <1%。很多人担心 MoE 的路由开销会抵消稀疏化的收益——这一页数值直接打消这个顾虑。

Router FLOPs

标准 sigmoid/softmax 路由器的核心计算是一个矩阵乘法:

$$\text{FLOPs}_{\text{router}} = 2 \times d \times E \times T_{\text{new}}$$

Nemotron($d=8192$,$E=512$,decode):

$$\text{FLOPs}_{\text{router}} = 2 \times 8192 \times 512 \times 1 = 8{,}388{,}608 \approx 8.4\text{M FLOPs}$$

M3($d=6144$,$E=128$,decode):

$$\text{FLOPs}_{\text{router}} = 2 \times 6144 \times 128 \times 1 = 1{,}572{,}864 \approx 1.6\text{M FLOPs}$$

对比单层 MoE 的专家计算量(激活 4-22 个专家,每个专家做 $2 \times d \times d_{ff}$ 或 $3 \times d \times d_{ff}$ 的 FFN):

  • Nemotron 单专家(ReLU$^2$,latent 空间):$2 \times 2048 \times 5120 \approx 21\text{M FLOPs}$
  • 激活 22 个专家:$\approx 462\text{M FLOPs}$

Router 的 8.4M FLOPs 占 462M 的 1.8%。在 M3(128 专家,激活 4 个)中占比更低。

DeepSeek V4 Flash 的 hash routing 稍复杂,但本质仍是查表+少量矩阵乘法,FLOPs 在百万量级,可忽略。

Router 就是给 512 扇门各配一把锁(一个 8192 维向量),新 token 来了用自己的 8192 维特征跟 512 把锁各算一次相似度。这个开销相当于一扇门打开后干活(一个专家 FFN)的几十分之一。Router 的 FLOPs 约等于半个 Attention 的 K 投影——在总计算量的大海里是一滴水。


3.9 Vision Encoder FLOPs

计算 ViT 编码器的 FLOPs,理解为什么视觉编码在总推理成本中的占比。多模态模型输入一张图时,ViT 要处理 576-2916 个 patch token——这部分计算量是“固定入场券”,与文本长度无关。

3.7.1 MiniMax M3 ViT FLOPs

M3 ViT:32 层,$d_{\text{vit}}=1280$,$H_{\text{vit}}=16$,$D_{\text{vit}}=80$,$d_{ff}^{\text{vit}}=5120$。

图像 token 数:$\left(\frac{2016}{14}\right)^2 = 144^2 = 20736$ patches,经过 pixel unshuffle($\times 4$ 压缩)后:$20736 / 4 = 5184$,再经 spatial merge:$5184 / 9 = 576$ tokens。本文取 576。

单层 Attention(标准 MHA):

$$\text{FLOPs}_{\text{ViT QKV}} = 4 \times 2 \times d_{\text{vit}} \times H_{\text{vit}} \times D_{\text{vit}} \times T_{\text{img}}$$$$= 8 \times 1280 \times 16 \times 80 \times 576 = 8 \times 1{,}638{,}400 \times 576$$$$= 8 \times 943{,}718{,}400 = 7{,}549{,}747{,}200 \approx 7.55\text{G FLOPs}$$

($4 \times 2 = 8$ 来自 Q、K、V、O 四个投影各 $2 \times m \times n \times k$)

QK 点积(causal 不适用,ViT 对图像做双向 Attention):

$$\text{FLOPs}_{\text{ViT QK}} = 2 \times H_{\text{vit}} \times T_{\text{img}}^2 \times D_{\text{vit}} = 2 \times 16 \times 576^2 \times 80$$$$= 2 \times 16 \times 331{,}776 \times 80 = 849{,}346{,}560 \approx 0.85\text{G FLOPs}$$

V 加权

$$\text{FLOPs}_{\text{ViT V}} = 2 \times H_{\text{vit}} \times T_{\text{img}}^2 \times D_{\text{vit}} = 0.85\text{G FLOPs}$$

单层 MLP(GELU,2 个矩阵):

$$\text{FLOPs}_{\text{ViT MLP}} = 2 \times 2 \times d_{\text{vit}} \times d_{ff}^{\text{vit}} \times T_{\text{img}}$$$$= 4 \times 1280 \times 5120 \times 576 = 4 \times 6{,}553{,}600 \times 576$$$$= 4 \times 3{,}774{,}873{,}600 = 15{,}099{,}494{,}400 \approx 15.1\text{G FLOPs}$$

单层合计:$7.55 + 0.85 + 0.85 + 15.1 \approx 24.35\text{G FLOPs}$

32 层合计:$32 \times 24.35\text{G} \approx 779\text{G FLOPs}$

加上 patch embedding、projector 等:$\approx 800\text{G FLOPs} = 0.8\text{T FLOPs}$(per image)。

对比文本骨干(60 层,prefill 4096 token,$\approx 100\text{T+ FLOPs}$),ViT 的 0.8T FLOPs 占比 <1%。

ViT 虽深(32 层),但维度小(1280 vs 6144)且 token 数固定(576 vs 4096+)。相当于“一辆 Smart 虽也能开到 120 迈,但跟重卡(文本骨干)不是一个吨位的”。

3.7.2 Kimi K2.5 ViT FLOPs(速算)

K2.5 ViT:27 层,$d_{\text{vit}}=1152$,$H_{\text{vit}}=16$,$D_{\text{vit}}=72$,$d_{ff}^{\text{vit}}=4304$。图像 token 数约 576-2916(取决于分辨率模式)。

用 576 token 近似:

$$\text{单层 Attn + MLP} \approx 8 \times 1152 \times 16 \times 72 \times 576 + 4 \times 1152 \times 4304 \times 576$$$$\approx 6.1\text{G} + 11.4\text{G} \approx 17.5\text{G FLOPs}$$

27 层:$\approx 0.47\text{T FLOPs}$。加上 PatchMerger 和投影器:$\approx 0.5-0.7\text{T FLOPs}$。


3.10 完整案例对比:1M 上下文下三种架构的 FLOPs

在同一张表中呈现纯 Full Attention、Nemotron Hybrid(Mamba + Attn)、M3 MSA 三种方案的 FLOPs 分解。这张表是 CH3 的终极输出——一行看懂 Mamba 和 MSA 为什么殊途同归地解决了 O(T²) 问题。

3.8.1 场景设定

  • 上下文长度:T = 1M tokens
  • 解码阶段:$T_{\text{new}} = 1$(单 token decode)
  • 对比模型:
    • 纯 Full Attn (hypothetical):60 层 Full Attention,$d=8192$,$H_q=64$,$H_{kv}=64$(MHA,无 GQA),$D_h=128$,SwiGLU FFN $d_{ff}=8192 \times 4 \approx 32768$(无 MoE 时 FFN 占比较小,此处简化用大维度)
    • Nemotron 3 Ultra (hybrid):48 层 Mamba-2 + 12 层 Attention(GQA 32:1,2 KV heads)+ 48 层 MoE(22/512 激活)。$d=8192$,$H_q=64$,$H_{kv}=2$,$D_h=128$。MoE 专家在 latent 空间计算。
    • M3 (MSA):57 层 MSA(GQA 16:1,4 KV heads)+ 3 层 Full Attention(GQA 16:1)+ 57 层 MoE(4/128 激活)。$d=6144$,$H_q=64$,$H_{kv}=4$,$D_h=128$。

3.8.2 逐项 FLOPs 分解(decode per token, T=1M)

Attention 部分(QK + V 加权)

模型Attention 层数单层 QK+V FLOPsAttn 部分合计
纯 Full Attn60$4 \times 64 \times 1\text{M} \times 128 = 32.8\text{G}$$60 \times 32.8\text{G} = 1.97\text{T}$
Nemotron Hybrid1232.8G (GQA 下 QK+V 仍为 $4 \times 64 \times T \times 128$)$12 \times 32.8\text{G} = 393.6\text{G}$
M3 MSA3 Full + 57 MSAFull: 32.8G(改用 $d=6144$,$H_q=64$,$H_{kv}=4$ 后实际 ~32.8G);MSA: Index QK 1.02G + Main QK+V 67.2M ≈ 1.09G$3 \times 32.8\text{G} + 57 \times 1.09\text{G} \approx 160.5\text{G}$

Mamba/SSD 部分

模型Mamba/SSD 层数单层 FLOPsMamba 部分合计
纯 Full Attn000
Nemotron Hybrid48~847M$48 \times 847\text{M} = 40.7\text{G}$
M3 MSA000

线性投影部分(QKV proj + O proj + in_proj + out_proj + FFN):

模型单层投影估算投影部分合计
纯 Full AttnQ(134M) + K(134M) + V(134M) + O(134M) + FFN(~1.6G) ≈ 2.14G$60 \times 2.14\text{G} \approx 128\text{G}$
Nemotron HybridAttn 投影(~277M) × 12 + Mamba 投影(~843M) × 48 + MoE FFN(~462M) × 48$\approx 3.3\text{G} + 40.5\text{G} + 22.2\text{G} \approx 66\text{G}$
M3 MSAMSA 投影(~220M) × 57 + Full Attn 投影(~220M) × 3 + MoE FFN(~220M) × 57$\approx 12.5\text{G} + 0.7\text{G} + 12.5\text{G} \approx 26\text{G}$

注:以上为近似量级估算。投影部分具体数值取决于 $d_{ff}$、MoE 专家数等配置细节,精确计算需代入各模型 config.json 的实际值。本表的重点是横比数量级差异。

3.8.3 总表

模型Attn QK+V 部分Mamba/SSD 部分线性投影总 FLOPs/token相对纯 Full Attn
纯 Full Attn (hypothetical)~1.97T0~128G~2.10T1×(基线)
Nemotron 3 Ultra (hybrid)~394G~41G~66G~501G~1/4
M3 (MSA)~161G0~26G~187G~1/11

核心发现:

  1. 纯 Full Attn 在 1M 上下文下几乎不可用:每产生一个 token 需要 2.1T FLOPs,单看 Attention QK+V 部分的 1.97T 占 94%。即使最强大的推理硬件也难以达到可接受的吞吐(2.1T / 989 TFLOPS(H100 FP16)$\approx 2.1$ 秒/ token)。

  2. Nemotron Hybrid 将 QK+V 开销砍到原来的 1/5(394G vs 1970G),因为 80% 的层(48/60)用 Mamba-2 完全避开了 O(T) Attention。但 12 个 Attention 层仍贡献了总 FLOPs 的 78%——12 个 Attention 层的成本超过了 48 个 Mamba 层的总和

  3. M3 MSA 更进一步:3 个 Full Attention 层占 98G 的 QK+V,57 个 MSA 层才占 62G(Index QK $57 \times 1.02\text{G} = 58.1\text{G}$ + Main QK+V $57 \times 0.067\text{G} = 3.8\text{G}$)。MSA 的 Index Branch 虽然仍是 O(T),但以 16× 的廉价系数执行。

  4. 殊途同归:Nemotron 用 Mamba-2(状态空间,O(1) decode),M3 用稀疏 Attention(O(T) 但系数极小)——两条不同的技术路线,但都在 1M 上下文上将 Attention 部分从 TFLOPs 量级压到了 GFLOPs 量级。原理不同,效果趋同。

3.8.4 不同上下文长度下的横比

为直观展示 O(T) vs O(1) 的差别,固定模型配置,变化 T。仅计算 Attention 相关的 QK+V 部分(不含投影和 FFN):

T纯 Full Attn QK+V (60层)Nemotron Hybrid Attn QK+V (12层)M3 QK+V (3 Full + 57 MSA)
4K$60 \times 4 \times 64 \times 4096 \times 128 = 8.05\text{G}$$12 \times 4 \times 64 \times 4096 \times 128 = 1.61\text{G}$3 Full: $3 \times 4 \times 64 \times 4096 \times 128 = 0.40\text{G}$
57 MSA Index: $57 \times 2 \times 4 \times 4096 \times 128 = 0.24\text{G}$
57 MSA Main: $57 \times 4 \times 64 \times 2048 \times 128 = 3.82\text{G}$
合计: ~4.46G
32K$60 \times 4 \times 64 \times 32768 \times 128 = 64.4\text{G}$$12 \times 4 \times 64 \times 32768 \times 128 = 12.9\text{G}$3 Full: $3 \times 4 \times 64 \times 32768 \times 128 = 3.22\text{G}$
57 MSA Index: $57 \times 2 \times 4 \times 32768 \times 128 = 1.91\text{G}$
57 MSA Main: $57 \times 4 \times 64 \times 2048 \times 128 = 3.82\text{G}$
合计: ~8.95G
128K$60 \times 4 \times 64 \times 131072 \times 128 = 258\text{G}$$12 \times 4 \times 64 \times 131072 \times 128 = 51.5\text{G}$3 Full: $3 \times 4 \times 64 \times 131072 \times 128 = 12.9\text{G}$
57 MSA Index: $57 \times 2 \times 4 \times 131072 \times 128 = 7.65\text{G}$
57 MSA Main: $57 \times 4 \times 64 \times 2048 \times 128 = 3.82\text{G}$
合计: ~24.4G
1M$60 \times 4 \times 64 \times 1\text{M} \times 128 = 1.97\text{T}$$12 \times 4 \times 64 \times 1\text{M} \times 128 = 394\text{G}$3 Full: $3 \times 4 \times 64 \times 1\text{M} \times 128 = 98.3\text{G}$
57 MSA Index: $57 \times 2 \times 4 \times 1\text{M} \times 128 = 58.4\text{G}$
57 MSA Main: $57 \times 4 \times 64 \times 2048 \times 128 = 3.82\text{G}$
合计: ~160.5G

注:M3 MSA 的 Main Branch 始终只在 2048 个入选 token 上做 Attention——与 T 无关,常数 3.82G。Index Branch 的 QK 评分随 T 线性增长但只有 4 个 head。Full Attention 的 3 层和 Index Branch 的 O(T) 项共同主导 M3 的长上下文成本。

观察

  • 4K 短上下文:三种方案差距较小(8.0G vs 1.6G vs 4.5G)。MSA 反而比纯 Full Attn(12 层)慢,因为 Index Branch 的额外开销 + Main Branch 选了 2048/4096=50% 的 token——稀疏化的好处在短序列上不明显。
  • 128K 中上下文:差距拉开(258G vs 52G vs 24G)。MSA Main Branch 仅访问 2048/131072 = 1.6% 的 token,而 Index Branch O(T) 项(7.7G)仍远小于 Full Attn O(T) 项(258G)。
  • 1M 长上下文:差距成为鸿沟(1970G vs 394G vs 161G)。MSA Main Branch 仅访问 2048/1M = 0.2% 的 token——近乎常数。M3 比纯 Full Attn 的 QK+V 部分快 ~12×,Nemotron Hybrid 快 ~5×。
  • 关键洞察:MSA 在超长上下文时 Main Branch 趋近于常数,Index Branch 成为唯一 O(T) 项。但因为 Index 只有 4 head,实际斜率仅为 Full Attn 的 1/16。MSA 本质是用 O(T) 斜率 1/16 的廉价计算替代全量 O(T)。

如果说短上下文(4K)是“在大厅里找人”,那长上下文(1M)就是“在鸟巢体育场里找人”。Full Attention 的做法是跟每一个观众对视一眼(O(T)),Mamba 的做法是先把体育场分片区,只跟片区组长沟通(chunk + state),MSA 的做法是先派几个侦察兵扫一眼观众席(Index Branch),找到目标区域后再派大队人马过去(Main Branch)。


3.11 速查表:FLOPs 公式汇总

给一张“查表即算”的公式大全。不需要重读整章,从这里抄公式代入 config.json 的数值即可。

组件公式适用场景
Q/K/V 投影$2 \times d \times (H_{\text{type}} \times D_h) \times T_{\text{new}}$Q 用 $H_q$,K/V 用 $H_{kv}$
QK 点积$2 \times H_q \times T_{\text{new}} \times T_{\text{total}} \times D_h$Prefill 时 $T_{\text{new}}=T_{\text{total}}$(causal 约 /2)
V 加权$2 \times H_q \times T_{\text{new}} \times T_{\text{total}} \times D_h$与 QK 等量级
O 投影$2 \times d \times (H_q \times D_h) \times T_{\text{new}}$与 Q 投影等量级
MLA $W_{kv_a}$$2 \times d \times (d_{kv} + D_{\text{rope}}) \times T_{\text{new}}$MLA 模型
MLA $W_{kv_b}$$2 \times d_{kv} \times H \times (D_{\text{nope}} + D_v) \times T_{\text{new}}$MLA 模型
MLA $W_{q_a}$$2 \times d \times d_q \times T_{\text{new}}$MLA 模型
MLA $W_{q_b}$$2 \times d_q \times H \times D_{\text{nope}} \times T_{\text{new}}$MLA 模型
MLA $W_{q_rope}$$2 \times d \times H \times D_{\text{rope}} \times T_{\text{new}}$MLA 模型
MSA Index QK$2 \times H_{\text{idx}} \times T_{\text{new}} \times T_{\text{total}} \times D_{\text{idx}}$M3 式 MSA
MSA Main QK/V$2 \times H_q \times T_{\text{new}} \times T_{\text{selected}} \times D_h$$T_{\text{selected}} = \text{block_size} \times \text{top_k}$
Mamba-2 in_proj$2 \times d \times (2d_{\text{inner}} + 2n_{\text{groups}}N + H_{\text{mamba}}) \times T_{\text{new}}$Nemotron 式 Mamba-2
Mamba-2 SSD diag$T \times C \times H_{\text{mamba}} \times D_{\text{mamba}}$Prefill; decode 时为常数
Mamba-2 SSD off-diag$T / C \times H_{\text{mamba}} \times N^2 \times 2$Prefill; decode 时常数可忽略
Mamba-2 out_proj$2 \times d_{\text{inner}} \times d \times T_{\text{new}}$总是
Router$2 \times d \times E \times T_{\text{new}}$所有 MoE 模型
FFN (ReLU$^2$)$2 \times 2 \times d \times d_{ff} \times T_{\text{new}}$Nemotron
FFN (SwiGLU)$2 \times 3 \times d \times d_{ff} \times T_{\text{new}}$M3, K2.5
ViT Attn$4 \times 2 \times d_{\text{vit}} \times H_{\text{vit}} \times D_{\text{vit}} \times T_{\text{img}}$VL 模型视觉编码器
ViT MLP (GELU)$2 \times 2 \times d_{\text{vit}} \times d_{ff}^{\text{vit}} \times T_{\text{img}}$VL 模型视觉编码器

实战口诀

  1. 先确定场景:prefill 还是 decode?
  2. 线性项(投影 + FFN):直接代入 $T_{\text{new}}$(prefill = 输入长度,decode = 1)
  3. 平方项(QK + V):将 $T_{\text{new}}$ 和 $T_{\text{total}}$ 分开——prefill 时两者相等,decode 时 $T_{\text{new}}=1$ 但 $T_{\text{total}}$ 是全部历史
  4. 稀疏/MSA 项:把 $T_{\text{total}}$ 换成 $T_{\text{selected}}$(入选 token 数)
  5. Mamba 项:decode 时全部为常数,prefill 时乘以 $T$
  6. 把每层加起来,乘以层数,得到单 token FLOPs
  7. 乘以 bytes 和 batch size 得到总计算吞吐需求

CH3 常见计算错误

#常见错误正确做法
1decode 时把 QKV 投影乘以 $T_{\text{total}}$decode 只投影 1 个新 token,投影 FLOPs 是常数
2GQA 下 QK 点积用 $H_{kv}$ 算QK 点积前 K 已经被 repeat_kv 扩展到 $H_q$,用 $H_q$ 算
3MLA 的 QK 点积以为能省 FLOPsMLA 省的是 KV cache(显存),不是 QK 点积 FLOPs——最终 attention 的 $D_h = D_{\text{nope}} + D_{\text{rope}}$ 与 MHA 相同
4把 prefill 的 causal /2 也用在 decodedecode 的 query 只有 1 个,不存在 causal mask 的对称简化,公式是 $T_{\text{new}} \times T_{\text{total}}$ 而非 $T^2/2$
5MSA 的 Index QK 以为不用算 O(T²)Index QK 仍然是 O(T²)(prefill)或 O(T)(decode),只是 head 数少(4 vs 64),系数省 16×
6Mamba-2 decode 时把 SSD 对角块按 O(T) 算Mamba-2 decode 是 O(1)——只需更新当前 chunk 的状态,不重算全部 chunk
7忘记乘 2(MAC 系数)深度学习框架中 1 MAC = 2 FLOPs,所有矩阵乘法公式必须有因子 2
8把参数数量当 FLOPs参数量是“存了多少数”,FLOPs 是“每次前向算多少下”,两者中间隔着序列长度 T(对 O(T) 项)或 T²(对 O(T²) 项)

下一章预告:CH 4 内存分析——KV Cache 大小推导、MLA/GQA 的缓存压缩比、显存带宽瓶颈(Roofline 模型)、batch size 与延迟的权衡。


CH 4 KV Cache 显存:原理、公式与多架构对比

计量约定:本章 KV cache 使用 GiB(1024³ bytes)。1 GiB = 1024³ bytes ≈ 1.074 GB。使用 1024 进制是因为 GPU 显存以 2 的幂次分配。T(序列长度)取 2^20 = 1,048,576。

本章定位:系统推导自回归推理中 KV cache 的显存占用公式,覆盖 MHA、GQA、MLA、MSA、Mamba-2 五种架构,并用 Kimi K2.5(MLA)、Nemotron 3 Ultra(GQA+Mamba)、MiniMax M3(MSA+GQA)的实测配置验证所有公式。


4.1 为什么需要 KV Cache

4.1.1 这节算什么

自回归推理时,模型每步只生成一个 token,但需要与所有历史 token 做 attention 运算。本节量化 KV cache 的本质:空间换时间——缓存中间结果,避免每步重新计算。

4.1.2 为什么重要

KV cache 是长上下文推理的第一瓶颈。1M 上下文中,纯 Attention 模型的 KV cache 可达数百 GiB,远超模型权重本身。架构选择(GQA、MLA、Mamba)的核心动机之一就是压缩或消除 KV cache。

4.1.3 直觉理解

看书时,读到第 100 页,不需要每翻一页就从头重读一遍——记住前面每一页的「关键信息」就够了。KV cache 就是模型在推理过程中对历史 token 的「关键信息摘要」。

标准自回归推理中,第 $t$ 步的 attention 需要对前 $t-1$ 个历史 token 计算 QK 点积:

$$\text{Attention}(Q_t, K_{1:t}, V_{1:t}) = \text{softmax}\left(\frac{Q_t \cdot K_{1:t}^T}{\sqrt{d_k}}\right) \cdot V_{1:t}$$

如果每步都重新计算 $K_{1:t}$ 和 $V_{1:t}$,第 $T$ 步的 FLOPs 将是 $O(T^2 \cdot d)$,总推理 FLOPs 为 $O(T^3 \cdot d)$。而缓存 KV 后,每步只需计算新 token 的 QKV 投影并与缓存中的 K、V 做 attention,总推理 FLOPs 降为 $O(T^2 \cdot d)$。


4.2 标准 MHA/GQA 的 KV Cache

4.2.1 这节算什么

从 MHA 和 GQA 的 attention 计算出发,推导 KV cache 的标准公式。这是所有 KV cache 分析的基准。

4.2.2 推导过程

第 1 步:每个 token 需要缓存什么?

标准自注意力中,对于序列中的每个历史 token,我们需要其 Key 向量和 Value 向量。每个 token 的 K 和 V 各一份,维度完全相同。

对于单个 attention head:

  • K shape: [head_dim]
  • V shape: [head_dim]

但实际存储是按 KV head 组织的(GQA 下 Q head 可以多于 KV head,此时多个 Q head 共享同一个 KV head)。

第 2 步:每层每 token 的缓存元素数

num_kv_heads = H_{kv}head_dim = D。每个 token 需要缓存 K 和 V 各一份:

$$\text{Cache elements per token per layer} = 2 \times H_{kv} \times D$$

其中每份 K 为 $H_{kv} \times D$ 个元素,V 同理。

第 3 步:完整模型公式

$$\text{KV Cache}_{total} = L_{attn} \times 2 \times H_{kv} \times D \times T \times \text{bytes\_per\_elem}$$

其中 $L_{attn}$ 为包含 attention 的层数,$T$ 为序列长度,$\text{bytes_per_elem}$ 取决于精度。
注意:如果模型包含非 attention 层(如 Mamba-2、纯 MLP 层),那些层不需要 KV cache,因此不参与计数。

4.2.3 直觉理解

  • $2 \times H_{kv} \times D$: 每层每 token 缓存 K+V 两个矩阵,每个矩阵有 H_kv 个 head × D 维 = 这就是一个 token 的「关键信息摘要」
  • $\times T$: 序列多长,缓存就多大——线性增长(这是 $O(T)$ 的)
  • $\times L_{attn}$: 每个 attention 层独立缓存
  • GQA 的省法:差异只在于 $H_{kv}$。$H_{kv}$ 越小,缓存越小

4.2.4 验证案例 1:Kimi K2.5(全 MHA,未使用 MLA 压缩时的理论值)

Kimi K2.5 使用全 MHA,即 $H_{kv} = H_Q = 64$,无 GQA 压缩。K 的有效维度为 $D_K = D_{nope} + D_{rope} = 192$(MLA 将 K 拆为 128 维内容 + 64 维位置),V 为 128 维。若不使用 MLA(仅作理论对比),在 $T = 256\text{K}$($262{,}144$ tokens)下,BF16 精度:

$$\text{KV Cache}_{no\_MLA} = 61 \times 64 \times (192 + 128) \times 262{,}144 \times 2$$$$= 61 \times 2 \times 8192 \times 262{,}144 \times 2 = 61 \times 32{,}768 \times 262{,}144 \times 2$$$$= 523{,}986{,}010{,}112 \text{ bytes} \approx 488.0 \text{ GiB}$$

直觉:这就是没有 MLA 压缩的代价——近 500 GiB,远超任何单 GPU 显存。这是 MLA 必须存在的根本原因。

4.2.5 验证案例 2:Nemotron 3 Ultra(GQA 32:1)

Nemotron 3 Ultra 仅有 12 层 Attention,使用极致 GQA($H_{kv} = 2$ 个 KV head),$D = 128$,在 $T = 1\text{M}$($1{,}048{,}576$ tokens)下,BF16:

$$\text{KV Cache}_{Nemotron} = 12 \times 2 \times 2 \times 128 \times 1{,}048{,}576 \times 2$$$$= 12 \times 512 \times 1{,}048{,}576 \times 2 = 12 \times 1{,}073{,}741{,}824$$$$= 12{,}884{,}901{,}888 \text{ bytes} = 12.0 \text{ GiB}$$

✅ 与 Nemotron 3 Ultra 技术报告声明的 ~12.0 GiB 完全一致。

为什么这么小?三个因素叠加:

  • 仅 12 层有 Attention(其余 48 层是 Mamba-2,不需要 KV cache)
  • GQA 32:1,$H_{kv}=2$——每层仅 2 个 KV head
  • 不使用 RoPE,head_dim=128 全部是「内容」维度

4.2.6 验证案例 3:MiniMax M3(GQA 16:1,主 KV cache)

MiniMax M3 全部 60 层使用 GQA 16:1($H_{kv}=4$),$D=128$,在 $T=1\text{M}$ 下,BF16:

$$\text{KV Cache}_{M3\_main} = 60 \times 2 \times 4 \times 128 \times 1{,}048{,}576 \times 2$$$$= 60 \times 1{,}024 \times 1{,}048{,}576 \times 2 = 60 \times 2{,}147{,}483{,}648$$$$= 128{,}849{,}018{,}880 \text{ bytes} = 120.0 \text{ GiB}$$

✅ 与 M3 报告声明的 ~120 GiB 完全一致。

4.2.7 GQA 压缩比公式

GQA 相对 MHA 的 KV cache 节省比例:

$$\text{Compression Ratio}_{GQA} = \frac{H_Q}{H_{kv}}$$

M3 的 GQA 16:1 意味着 KV cache 仅为 MHA 的 $1/16$。Nemotron 的 32:1 节省更极致。

一个全 MHA 60 层模型($H_{kv}=64$, $D=128$)在 1M 上下文的 KV cache:

$$60 \times 2 \times 64 \times 128 \times 1{,}048{,}576 \times 2 = 1{,}886{,}621{,}245{,}440 \text{ bytes} \approx 1{,}758 \text{ GiB}$$

这是不可部署的。GQA 是长上下文推理的基本生存策略。


4.3 MLA 的 KV Cache(Multi-head Latent Attention)

4.3.1 这节算什么

MLA 是本章最复杂的部分。MLA(DeepSeek V2/V3 提出,Kimi K2 系列继承)通过低秩压缩改变 KV cache 的存储对象——不再直接缓存 K 和 V,而是缓存一个低秩潜向量 $\mathbf{c}_t^{KV}$ 和一个额外的 RoPE 分量。本节从 shape 角度逐步推导 MLA 的缓存公式,并用 Kimi K2.5 的实测配置验证。

4.3.2 为什么重要

MLA 是当前 MoE 模型(DeepSeek V3/R1、Kimi K2 系列)实现长上下文推理的关键技术。不压缩时 K2.5 的 KV cache 高达 ~732 GiB(见下),MLA 将其压缩到约 21.5 GiB——压缩比 ~34 倍。理解 MLA 的缓存公式是评估 MoE 推理成本的前提。

4.3.3 核心问题

MLA 的 K 和 V 不是直接存储的——它们从一个共享的低秩潜向量 $\mathbf{c}_t^{KV}$ 通过升维投影得到。那么推理时 cache 应该存什么?是存完整的 K 和 V(失去了 MLA 的意义),还是存压缩后的潜向量?

答案:缓存 $\mathbf{c}_t^{KV}$(共享潜向量,可同时解压出 K 和 V)+ $\mathbf{k}_t^R$(RoPE 位置分量,不可压缩)。

4.3.4 推导过程:从 Shape 角度一步一步来

第 1 步:标准 Attention 的 K 是什么

在标准 MHA 中,每个 token 的 K 是一个形状为 $[H_{kv}, D_K]$ 的矩阵。以 Kimi K2.5 为例(全 MHA, $H_{kv} = H_Q = 64$),其 MLA 架构中 K 的实际维度为 $D_K = D_{nope} + D_{rope} = 128 + 64 = 192$:

$$\text{K cache per token per layer} = 64 \times 192 = 12{,}288 \text{ 个元素}$$

V 的维度为 $D_v = 128$:$64 \times 128 = 8{,}192$ 个元素。合计 $20{,}480$ 个元素。

第 2 步:MLA 如何计算 K——分为两块

MLA 将 K 分为两个功能不同的分量:

分量 1:$\mathbf{k}^{nope}$(内容分量,128 维 per head)

$$ \mathbf{c}_t^{KV} = \mathbf{x}_t \cdot \mathbf{W}_{kv\_down} \in \mathbb{R}^{512} $$$$ \mathbf{K}_{t}^{nope} = \mathbf{c}_t^{KV} \cdot \mathbf{W}_{k\_up} \in \mathbb{R}^{64 \times 128} $$

其中 $\mathbf{c}t^{KV}$ 是 512 维的潜向量,通过共享的 $\mathbf{W}{kv_down}$ 投影得到。然后通过 $\mathbf{W}_{k_up}$ 升维到 64 个 head × 128 维的完整 K(仅 nope 部分)。

关键:$\mathbf{K}^{nope}$ 是 64 × 128 = 8,192 维的矩阵,但它完全由 512 维的 $\mathbf{c}_t^{KV}$ 决定——所以不需要缓存 8,192 维,只需缓存 512 维。

分量 2:$\mathbf{k}^{rope}$(位置分量,64 维)

RoPE 是一个正交旋转变换,施加在 K 的头维度上。按照 MLA 的设计,RoPE 部分使用 MQA(Multi-Query Attention)方式共享:所有 64 个 attention head 使用同一个 RoPE Key 向量,维度为 $d_{rope} = 64$(即 qk_rope_head_dim)。

$$\mathbf{k}_t^R = \text{RoPE}(\mathbf{x}_t \cdot \mathbf{W}_{kr}) \in \mathbb{R}^{64}$$

每个 head $i$ 的完整 K 为:

$$\mathbf{K}_{t,i} = [\mathbf{k}_t^R \,;\, \mathbf{K}_{t,i}^{nope}] \in \mathbb{R}^{64 + 128 = 192}$$

为什么 $\mathbf{k}^R$ 不能被压缩? RoPE 是施加在完整 K 上的旋转变换——位置编码依赖具体的坐标值,不能通过低秩近似保留。因此 $\mathbf{k}^R$ 必须独立缓存。但由于它采用 MQA 共享(而非每 head 一份),实际缓存量很小。

第 3 步:MLA 如何计算 V

V 完全从 $\mathbf{c}_t^{KV}$ 解压得到,没有 RoPE 分量:

$$\mathbf{V}_t = \mathbf{c}_t^{KV} \cdot \mathbf{W}_{v\_up} \in \mathbb{R}^{64 \times 128}$$

关键:V 是 64 × 128 = 8,192 维,但完全由 512 维的 $\mathbf{c}_t^{KV}$ 决定——因此 V 也不需要单独缓存。

第 4 步:Cache 里到底存什么

综合第 2、3 步,每个 token 每层缓存的元素数为:

缓存项维度是否可压缩备注
$\mathbf{c}_t^{KV}$kv_lora_rank = 512这是压缩形式同时编码 K_nope 和 V
$\mathbf{k}_t^R$qk_rope_head_dim = 64不可压缩MQA 共享,所有 head 复用

合计:$512 + 64 = 576$ 个元素 per token per layer。

对比标准 Attention:$64 \times 192 + 64 \times 128 = 20{,}480$ 个元素。MLA 压缩比为 $20{,}480 / 576 \approx 35.6\times$。

第 5 步:Per Token Per Layer 公式

$$\text{Cache per token per layer}_{MLA} = (\text{kv\_lora\_rank} + \text{qk\_rope\_head\_dim}) \times \text{bytes\_per\_elem}$$

注意:这里不是 $\times 2$! 标准 Attention 的 $\times 2$ 是因为 K 和 V 各自独立存储。而 MLA 中 kv_lora_rank 的单个潜向量同时编码了 K_nope 和 V——一份存储,两份产出。

第 6 步:完整模型公式

$$\text{KV Cache}_{MLA} = L \times (\text{kv\_lora\_rank} + \text{qk\_rope\_head\_dim}) \times T \times \text{bytes\_per\_elem}$$

其中 $L$ 为模型总层数(MLA 通常在所有层使用)。

4.3.5 验证:代入 Kimi K2.5

配置回顾config.json):

  • $L = 61$ 层,全部使用 MLA
  • kv_lora_rank = 512
  • qk_rope_head_dim = 64
  • $T = 256\text{K} = 262{,}144$ tokens
  • $\text{bytes_per_elem} = 2$(BF16)

代入公式

$$\text{KV Cache}_{K2.5} = 61 \times (512 + 64) \times 262{,}144 \times 2$$$$= 61 \times 576 \times 262{,}144 \times 2$$$$= 61 \times 302{,}055{,}168 = 18{,}425{,}365{,}248 \text{ bytes}$$$$= 17.2 \text{ GiB}$$

与报告声明的对比:Kimi K2.5 技术报告声明 256K 时 KV cache 约 21.5 GiB。公式推导结果(17.2 GiB ≈ 18.4 GiB)与报告值差异约 15%。这一差异的可能来源:

  1. KV cache 对齐开销:GPU 显存通常以 128B 或 256B 对齐,每层每 token 额外开销约为 5-10%
  2. 额外缓存结构:部分 MLA 实现可能缓存额外的元数据(如 index/causal mask 的辅助结构)
  3. 报告舍入误差:技术报告中的数字通常做了一定程度的舍入

综合考虑对齐开销后约为 $17.2 \times 1.05 \approx 18.0 \text{ GiB}$,与 21.5 GiB 仍在同一数量级。

4.3.6 MLA 的直觉理解

  • 「两本账合一」:标准 Attention 需要分别存 K 和 V 两本账($\times 2$)。MLA 把两本账的信息压缩到同一个潜向量 $\mathbf{c}_t^{KV}$ 里——一个 512 维向量同时包含了 K 和 V 的精华
  • 「位置信息外包」:RoPE 不能压缩,但 MLA 巧妙地将 K 的 RoPE 部分用 MQA 方式共享(所有 head 共用一个 $\mathbf{k}^R$),而不是每个 head 存一份
  • 「为什么 MLA 比纯 GQA 更省」:GQA 只是减少了 KV head 数量(空间省但内容信息量受限),MLA 进一步在每 head 内部做低秩压缩——相当于 GQA 省宽度,MLA 省深度

4.3.7 MLA 压缩比的极限分析

MLA 的压缩比:

$$\text{Compression Ratio}_{MLA} = \frac{2 \times H_{kv} \times D}{\text{kv\_lora\_rank} + \text{qk\_rope\_head\_dim}}$$

以 K2.5 为例:

$$\frac{64 \times 192 + 64 \times 128}{512 + 64} = \frac{20{,}480}{576} \approx 35.6\times$$

压缩比的结构分解

  • 来自「K+V 共享潜向量」:$\times 2 \to \times 1$(省 50%)
  • 来自「低秩压缩 $8{,}192 \to 512$」:约 16 倍
  • 来自「$\mathbf{k}^R$ 的 MQA 共享」:64 head $\to$ 1 个共享向量(省约 64 倍)

三项叠加:$2 \times 16 \approx 32\times$,减去 $\mathbf{k}^R$ 开销后约 28 倍。


4.4 MSA 的 KV Cache(MiniMax Sparse Attention)

4.4.1 这节算什么

MiniMax M3 的 MSA(MiniMax Sparse Attention)在标准的 GQA KV cache 之上,额外引入了一组 Index K cache——用于 block-level 稀疏选择的轻量评分 Key。本节量化 MSA 的额外缓存开销。

4.4.2 为什么重要

MSA 的稀疏性体现在计算(每次只访问 top-16 blocks),但不体现在存储(所有 KV 仍需缓存,因为不同 query 可能选择不同的 blocks)。理解这一点才能正确评估 MSA 的显存需求——MSA 的加速来自计算 FLOPs 的减少,而不是 KV cache 的减少。

4.4.3 主 KV Cache:与标准 GQA 完全相同

MSA 不改变 K 和 V 的存储方式。60 层全部缓存主 KV,与标准 GQA 公式一致:

$$\text{KV Cache}_{M3\_main} = 60 \times 2 \times 4 \times 128 \times T \times 2 = 120.0 \text{ GiB at } T = 1\text{M}$$

计算过程已在 4.2.6 节验证,与 M3 报告声明的 ~120 GiB 完全一致。

4.4.4 Index K Cache:MSA 的额外开销

MSA 的 Index Branch(MiniMaxM3VLIndexer)用于为每个 query 从 $B = \lceil T / 128 \rceil$ 个 block 中评选出 top-16。Index Branch 需要缓存一个独立的 Index Key:

Index K 的 shape

  • $n_{idx_heads} = 4$(4 个 index head 用于多角度评分)
  • Index K head: 只有 1 个(被所有 4 个 index head 通过广播共享)
  • sparse_index_dim = 128
$$\text{Index K elements per token per layer} = 1 \times 128 = 128$$$$\text{Index K cache per token per layer} = 128 \times 2 = 256 \text{ bytes (BF16)}$$

完整公式(仅 MSA 层,即 57 层):

$$\text{KV Cache}_{M3\_index} = L_{MSA} \times H_{idx\_k} \times D_{idx} \times T \times \text{bytes\_per\_elem}$$

代入 M3 配置($L_{MSA} = 57$, $H_{idx_k} = 1$, $D_{idx} = 128$, $T = 1\text{M}$):

$$= 57 \times 1 \times 128 \times 1{,}048{,}576 \times 2$$$$= 57 \times 268{,}435{,}456 = 15{,}300{,}820{,}992 \text{ bytes} = 14.25 \text{ GiB}$$

✅ 与 M3 报告声明的 ~14.2 GiB 完全一致。

4.4.5 M3 总 KV Cache

$$\text{KV Cache}_{M3\_total} = 120.0 + 14.25 = 134.25 \text{ GiB at } T = 1\text{M}$$

其中主 KV cache 占 89.4%,Index K cache 占 10.6%。Index K cache 虽然每 token 只有 128 个元素(vs 主 KV 的 $2 \times 4 \times 128 = 1{,}024$ 个元素),但涉及 57 层,总计也达到了不可忽略的 ~14 GiB。

4.4.6 直觉理解

  • MSA 省计算,不省存储:主 KV cache 与 Full Attention 一模一样——所有历史 token 的 K 和 V 都必须保留,因为不同 query 会选择不同的 top-16 blocks
  • Index K cache 是「目录索引」的代价:在 1M 上下文中,需要额外的 ~14 GiB 来存储这份目录索引,但换来 decode 计算 30 倍加速(参见 M3 报告 CH3.6)
  • Index K 的 MQA 共享:4 个 index head 共享 1 个 index key,如果用 4 个独立的 index key,开销将是 $14.25 \times 4 = 57 \text{ GiB}$

4.5 Sliding Window Attention 的 KV Cache

SWA 的 KV cache 公式与标准 Attention 完全相同——window 只是限制了计算时"看多远",不影响"存多少"。KV cache 仍然需要缓存全部历史 token:

$$M_{\text{kv}}^{\text{SWA}} = 2 \times L \times H_{kv} \times D_h \times T \times \text{bytes}$$

计算时只取最后 $W$ 个 token 参与注意力。这意味着 SWA 在长上下文推理时,KV cache 显存与 Full Attention 完全相同,仅计算量有节省。

对比:如果 $T = 1\text{M}$,$W = 131\text{K}$,KV cache 按 $T$ 存(~120 GiB for M3 的 GQA 配置),但 FLOPs 按 $W$ 算(~2.15 GFLOPs/layer vs ~17.2 GFLOPs/layer for Full Attn)。SWA 的定位是"省计算不省显存"。

4.6 Gated DeltaNet / Linear Attention 的状态空间

Gated DeltaNet 没有传统 KV cache——它用一个固定大小的矩阵 $S \in \mathbb{R}^{H \times D_h \times D_h}$ 替代:

$$M_{\text{state}}^{\text{DeltaNet}} = L \times H \times D_h^2 \times \text{bytes\_per\_elem}$$

以 Qwen3.5-MoE 为例($L$ 层,$H = 64$,$D_h = 128$,BF16):$L \times 64 \times 128^2 \times 2 = L \times 2.1\text{MB}$。假设 $L = 48$:$48 \times 2.1\text{MB} \approx 100\text{MB}$。

对比 Attention 的 KV cache($T = 1\text{M}$):$2 \times 48 \times H_{kv} \times 128 \times 1\text{M} \times 2$——即使 $H_{kv} = 2$(极端 GQA)也是 $2 \times 48 \times 2 \times 128 \times 10^6 \times 2 \approx 49\text{GB}$。差距约 500×

DeltaNet 和 Mamba-2 的选择差异:DeltaNet 的状态是 $O(H \times D_h^2)$——矩阵形状的。Mamba-2 的状态是 $O(H \times N)$——向量形状的,$N \ll D_h$。DeltaNet 的"记忆"更丰富(矩阵可以存更多信息),但代价是状态更新($O(D_h^2)$)比 Mamba-2 的状态传递($O(N^2)$)更贵。这是计算-记忆的 trade-off。

4.7 无 KV Cache 的架构:Mamba-2

4.5.1 这节算什么

Mamba-2(State Space Duality)用固定大小的循环状态替代随序列长度线性增长的 KV cache。本节量化 Mamba 的状态开销,并与 Attention 的 KV cache 做对比。

4.5.2 为什么重要

Mamba 代表了「彻底消除 KV cache」的架构方向。理解 Mamba 的状态开销是评估混合架构(如 Nemotron 3 Ultra = 48 Mamba-2 + 12 Attention)显存优势的前提。

4.5.3 状态空间模型的状态

Mamba-2 的循环递推形式为:

$$h_t = A_t h_{t-1} + B_t x_t$$


$$y_t = C_t h_t + D x_t$$

其中隐状态 $h_t \in \mathbb{R}^{H_{ssm} \times d_{state}}$。对于 Nemotron 3 Ultra:

  • $H_{ssm} = 256$(256 个 SSD head)
  • $d_{state} = 128$(每 head 的状态维度)

每层状态大小(与序列长度无关):

$$\text{State size per layer} = H_{ssm} \times d_{state} \times \text{bytes\_per\_elem}$$

代入:

$$= 256 \times 128 \times 4 \text{ bytes (FP32 cache)} = 131{,}072 \text{ bytes} = 128 \text{ KiB}$$

48 层 Mamba-2 总状态:

$$48 \times 131{,}072 = 6{,}291{,}456 \text{ bytes} \approx 6.0 \text{ MiB}$$

4.5.4 对比:Mamba 状态 vs Attention KV Cache

在 $T = 1\text{M}$ 上下文下:

架构存储与 $T$ 的关系
12 层 Attention (GQA 32:1)12.0 GiB$\propto T$
48 层 Mamba-26.0 MiB常数(与 $T$ 无关)
60 层全 Attention (MHA 64 heads)1,758 GiB$\propto T$

Mamba-2 的状态仅为 12 层 Attention KV cache 的约 1/2000。这就是混合架构(如 Nemotron 3 Ultra)的核心推理效率优势:Mamba-2 层以恒定大小的循环状态替代了 KV cache,使长上下文推理的显存开销主要由少量的 Attention 层决定。

KV Cache 自查清单(算完后对照):

  • 公式中的 ×2 是 K+V 各一份?不是 ×4?
  • GQA 用 H_kv(不是 H_q)?KV head 数少了显存就省了?
  • MLA 的 c_t^{KV} 同时编码 K_nope 和 V → 不需要 ×2?
  • MLA 的 k_rope 维度 = H × qk_rope_head_dim(不是 H_kv × head_dim)?
  • Mamba 层没有 KV cache → 仅 Attention 层计入?
  • 你的数在合理范围吗?256K 时全 MHA ~数百 GiB,MLA ~20 GiB,Mamba-2 <10 MB?

4.5.5 直觉理解

  • 「看书 vs 记笔记」:Attention 是把整本书的每一页都摊在桌上(KV cache $\propto T$),Mamba 是看完一页记一行笔记(固定大小的状态)
  • 「状态是压缩的上下文」:128 维的状态向量是前文所有信息的压缩表示——信息量有限但足以支撑后续推理
  • 「代价是信息损失」:Mamba 的固定状态必然丢失细节——这就是为什么 Nemotron 保留了 12 层 Attention(周期性全局交互补充 Mamba 丢失的长程细节)

4.8 视觉 Token 的 KV Cache 增量

4.6.1 这节算什么

多模态模型(M3、K2.5)中,图像和视频 token 也需要 KV cache。本节量化视觉 token 对 KV cache 的额外贡献。

4.6.2 为什么重要

一张高分辨率图像(如 M3 的 576 visual tokens)在长上下文推理中可能占据显著的 cache 份额。如果输入包含多张图片或视频帧,视觉 token 的 cache 增量不可忽略。

4.6.3 计算公式

视觉 token 对 KV cache 的增量与文本 token 使用完全相同的公式,只是 $T$ 增加了视觉 token 数量:

$$\Delta \text{KV Cache}_{visual} = L_{attn} \times 2 \times H_{kv} \times D \times T_{visual} \times \text{bytes\_per\_elem}$$

对于 M3(GQA 16:1, $H_{kv}=4$, $D=128$, BF16),1 张图(576 visual tokens):

$$\Delta_{1\_image} = 60 \times 2 \times 4 \times 128 \times 576 \times 2 = 60 \times 1{,}024 \times 576 \times 2$$$$= 60 \times 1{,}179{,}648 = 70{,}778{,}880 \text{ bytes} \approx 66.0 \text{ MiB}$$

10 张图:$\approx 659 \text{ MiB}$。100 张图:$\approx 6.6 \text{ GiB}$。

对于 K2.5 MLA($L=61$, kv_lora_rank=512, qk_rope_head_dim=64, $T_{visual} = 1024$ per image):

$$\Delta_{1\_image} = 61 \times (512 + 64) \times 1024 \times 2 = 61 \times 576 \times 1024 \times 2$$$$= 61 \times 1{,}179{,}648 = 71{,}958{,}528 \text{ bytes} \approx 67.0 \text{ MiB}$$

注意:MLA 压缩后,每视觉 token 的 cache 增量为 1,152 bytes(vs 标准 GQA 的 2,048 bytes),单张图差异不大,但在大量图片的场景下 MLA 的优势会累积。


4.9 完整案例对比

4.7.1 三个模型的全量 KV Cache 表

模型架构$L_{attn}$KV 公式关键参数256K Cache1M Cache
Kimi K2.5MLA (全 MHA)61$L \times (lora + d_{rope}) \times T \times 2$lora=512, drope=64~17 GiBN/A(不支持 1M)
Nemotron 3 UltraGQA + Mamba12$L \times 2 \times H_{kv} \times D \times T \times 2$H_kv=2, D=128~3 GiB~12 GiB
MiniMax M3MSA + GQA60 (+57 index)$L \times 2 \times H_{kv} \times D \times T \times 2$ + indexH_kv=4, D=128~30 + 3.6 GiB~120 + 14 GiB
假设纯 Full Attn 60 层MHA60$L \times 2 \times H_{kv} \times D \times T \times 2$H_kv=64, D=128~440 GiB~1,758 GiB

4.7.2 这张表告诉我们什么

  1. 架构选择直接决定部署可行性。纯 Full Attention 60 层模型在 1M 上下文需要 1.76 TiB KV cache——没有任何单 GPU 可以承载。而 Nemotron 3 Ultra 仅需 12 GiB(约 1/150),M3 需 134 GiB(约 1/13)。

  2. MLA 是当前 KV cache 压缩最强的 Attention 方案。K2.5 的 MLA 实现了 ~34× 压缩——仅用 ~21.5 GiB 就支撑了 61 层全 MHA 的 256K 上下文。作为对比,若不用 MLA(纯 MHA),同样配置需要 ~732 GiB。采用正确的 K 维度(192 = 128+64)计算。

  3. Mamba-2 是消除 KV cache 的根本方案。Nemotron 的 48 层 Mamba-2 仅需 6 MiB 状态存储(与序列长度无关),而 12 层 Attention 在 1M 时需要 12 GiB。混合架构的本质是用少量 Attention 层换取全局交互能力,用大量 Mamba 层换取 KV-cache-free 的长程编码。

  4. MSA 是「半方案」——它有效减少计算(decode 加速 30 倍),但不减少存储。M3 的 1M KV cache 高达 134 GiB,仍是部署瓶颈。将 MSA 与 KV cache 量化(FP8/INT4)或 token eviction 结合是自然的演进方向。

4.7.3 各架构 KV Cache 增长曲线(概念性公式)

架构KV Cache 复杂度116K 典型值1M 典型值
全 MHA (60 层)$O(L \cdot T)$~220 GiB~1,758 GiB
GQA 16:1 (60 层)$O(L \cdot T / R_{GQA})$~30 GiB~120 GiB
MLA (61 层, K2.5)$O(L \cdot T / R_{MLA})$~8 GiB~67 GiB
Mamba-2 (48 层)$O(L)$ — 常数~6 MiB~6 MiB
混合 (12 Attn + 48 Mamba)$O(L_{attn} \cdot T)$ + 常数~3 GiB~12 GiB

4.7.4 工程结论

在部署长上下文 LLM 时,KV cache 的架构选择遵循以下优先级:

  1. 如果任务不需要完美 recall:Mamba-heavy 混合架构(如 Nemotron 3 Ultra)是最优解——极致 GQA + 最少 Attention 层
  2. 如果需要高精度长程 attention:MLA 优于纯 GQA——同样 KV head 数下,MLA 通过低秩压缩再省 10-30 倍
  3. 如果需要白盒一致性和全 attention 质量:MSA 减少计算但需承受全量 KV cache 存储——适合计算瓶颈而非显存瓶颈的场景
  4. KV cache 量化(FP8/INT4)是通用的叠加优化:可与上述任何架构组合使用,通常再压缩 2-4 倍

4.10 公式速查表

公式适用架构说明
$L \times 2 \times H_{kv} \times D \times T \times \text{bpe}$MHA / GQA标准 KV cache,$\times 2$ 来自 K+V
$L \times (\text{kv_lora_rank} + \text{qk_rope_head_dim}) \times T \times \text{bpe}$MLA潜向量 $\mathbf{c}_t^{KV}$ + 共享 RoPE key $\mathbf{k}^R$
$L_{MSA} \times H_{idx_k} \times D_{idx} \times T \times \text{bpe}$MSA (Index)额外的 Index K cache
$L_{ssm} \times H_{ssm} \times d_{state} \times \text{bpe}$Mamba-2固定大小,与 $T$ 无关
$\text{bpe}$bytes per element: BF16=2, FP32=4, FP8=1, INT4=0.5


CH 5 推理显存 & CH 6 完整实战推演

读者定位:已掌握 CH 1-2(config.json 读取 + 参数分解)和 CH 3-4(FLOPs 估算 + KV Cache)的工程师,目标是从参数/FLOPs/KV Cache 出发,计算任意模型在给定硬件上的推理部署方案。


CH 5 | 推理显存——「部署需要多少卡」

计量约定:本章显存估算使用 GiB(1024³ bytes),贴近 GPU 硬件规格。1 GiB ≈ 1.074 GB。

5.1 显存预算的三部分

建立推理显存的三要素分解框架。算完 FLOPs 只知道"算得动吗",算完显存才知道"装得下吗"——后者往往是真正的瓶颈,因为模型权重在推理期间必须常驻显存。

推理一块 GPU 需要同时装下三样东西:

$$\text{Total Memory} = \underbrace{M_{\text{weights}}}_{\text{模型权重}} + \underbrace{M_{\text{kv}}}_{\text{KV Cache}} + \underbrace{M_{\text{act}}}_{\text{激活 + 临时缓冲}}$$

三者的比例关系随模型架构不同变化巨大。以下是一个典型 MoE 模型(如 Nemotron 550B)在 1M 上下文、BF16 推理时的显存分配比例(ASCII 图):

Total ∼1,128 GiB (8×H200)
┌──────────────────────────────────────────────────────────────────┐
│██████████████████████████████████████████████████████████████    │  Weights: ∼1,100 GiB (97.5%)
│KV Cache: ∼13 GiB (1.2%)                                           │
│Act+Overhead: ∼15 GiB (1.3%)                                       │
└──────────────────────────────────────────────────────────────────┘

而同一个 1,128 GiB 池子上,M3 BF16 推理的显存分配:

Total ∼1,005 GiB (per sample, 1M context)
┌──────────────────────────────────────────────────────────────────┐
│████████████████████████████████████████████████████████          │  Weights: ∼856 GiB (85%)
│██████████████████████                                            │  KV Cache: ∼144 GiB (14.3%)
│Act: ∼5 GiB (0.5%)                                                 │
└──────────────────────────────────────────────────────────────────┘

Nemotron 的 Attention 层只有 12 层且 GQA 32:1 极度压缩 KV Cache,所以 KV Cache 占比极小;M3 有 60 层全部存 KV Cache(包括 MSA Index K),在 1M 上下文下 KV Cache 膨胀到权重的 ~17%。架构差异直接导致显存瓶颈的转移——Nemotron 是纯权重瓶颈,M3 是权重+KV Cache 双瓶颈。


5.2 权重显存

从总参数量直接换算权重占用的显存。这是显存预算的最大头,也是最容易算的部分——总参 × 精度字节数。

公式

$$M_{\text{weights}} = N_{\text{total}} \times \text{bytes\_per\_param}$$

按精度的换算表

精度bytes/param550B 模型需要428B 模型需要
FP3242,200 GiB1,712 GiB
BF16 / FP1621,100 GiB856 GiB
FP8 (E4M3)1550 GiB428 GiB
INT4 / NVFP40.5275 GiB214 GiB

:本文中 GiB 指 decimal GiB($10^9$ bytes),与 GPU 厂商(NVIDIA H200 标称 141 GiB)的 marketing 单位一致。如需二进制 GiB($2^{30}$ bytes),乘以 $10^9 / 2^{30} \approx 0.931$。

案例 1:Nemotron 3 Ultra(550B)

BF16 推理:

$$M_{\text{weights}} = 550 \times 10^9 \times 2 = 1.1 \times 10^{12} \text{ bytes} = \mathbf{1{,}100 \text{ GiB}}$$

换成 FP8 量化:

$$M_{\text{weights}} = 550 \times 10^9 \times 1 = 5.5 \times 10^{11} \text{ bytes} = \mathbf{550 \text{ GiB}}$$

从 1,100 GiB 降到 550 GiB,可以直接从"必须 8 卡"变为"4 卡可行"(4 × 141 = 564 GiB)。

BF16 下,每 1B 参数 = 2 GiB 显存。550B = 1,100 GiB,428B = 856 GiB。这个换算可以心算:参数量(B)× 2 = 显存(GB,BF16 下)。

案例 2:MiniMax M3(~428B)

BF16 推理:

$$M_{\text{weights}} = 428 \times 10^9 \times 2 = \mathbf{856 \text{ GiB}}$$

FP8:

$$M_{\text{weights}} = 428 \times 10^9 \times 1 = \mathbf{428 \text{ GiB}}$$

案例 3:Kimi K2.5(~1T)

BF16 推理(如果全量加载):

$$M_{\text{weights}} = 1{,}000 \times 10^9 \times 2 = \mathbf{2{,}000 \text{ GiB}} \approx 2 \text{ TB}$$

需要 $\lceil 2000 / 141 \rceil = 15$ 张 H200 才能装下 BF16 权重。实际部署中 K2.5 使用 FP8 量化(1,000 GiB ≈ 8 卡)或 INT4(500 GiB ≈ 4 卡)。

MoE 的权重加载特殊性

上述计算假设所有权重全部驻留在显存中(全量加载)。这是标准推理部署的做法——即使 MoE 每 token 只激活 $k/E$ 的专家,所有 $E$ 个专家的权重仍需在显存中,因为不同 token 激活不同专家。

但存在一种"按需加载"策略:只将当前 batch 需要的专家权重换入显存,不需要的留在 CPU 或 NVMe 上。这种策略的显存占用为:

$$M_{\text{weights}}^{\text{on-demand}} = M_{\text{non-MoE}} + \overbrace{k_{\text{batch}} \times M_{\text{per-expert}}}^{\text{仅加载被命中的专家}}$$

其中 $k_{\text{batch}}$ 是整个 batch 激活的不同专家数(不是 $k$,因为 batch 中不同 token 可能命中不同专家,总的命中专家数随 batch size 增大而增大)。

按需加载的优势是省显存,代价是延迟不可预测(换入专家需要 PCIe/NVLink 带宽)。目前生产部署几乎不使用按需加载——延迟的不可预测性是服务级推理不能接受的。


5.3 KV Cache 显存

从 KV Cache 的公式化计算出发,给出 per-sample 和 per-batch 的显存占用量。KV Cache 与序列长度成线性正比。在 1M 上下文下,它可能膨胀到与权重同量级。

核心公式(沿用 CH 4)

标准 GQA:

$$M_{\text{kv}}^{(1)} = L \times 2 \times H_{kv} \times D_h \times T \times \text{bytes\_per\_elem}$$

其中:

  • $L$:层数
  • $2$:K 和 V 两份
  • $H_{kv}$:KV 头数
  • $D_h$:每头维度
  • $T$:序列长度(cached tokens)
  • $\text{bytes_per_elem}$:BF16=2,FP8=1

每一层有两个缓存矩阵(K 和 V),每个形状是 $H_{kv} \times T \times D_h$(GQA 下 KV 头数少于 Q 头数,矩阵较窄)。60 层 × 2 份 × 4 头 × 128 维 × 1M token × 2 字节 = 60 × 2 × 4 × 128 × 10^6 × 2 ≈ 123 GiB。记法:每层 KV Cache ≈ $2 \times H_{kv} \times D_h \times T \times 2$ bytes。

针对不同模型架构的扩展

MLA(Kimi K2.5):KV Cache 只存压缩后的潜向量,不存展开后的全维度 K/V。公式变为:

$$M_{\text{kv}}^{\text{MLA}} = L \times (d_{kv} + D_{rope}) \times T \times \text{bytes\_per\_elem}$$

其中 $d_{kv}$ 是 KV 压缩维度,$D_{rope}$ 是 RoPE 分量(不可压缩,必须单独存储)。K2.5 中 $d_{kv}=512$,$D_{rope}=64$,合计 $576$ 维。对比标准 MHA($64 \times (192 + 128) = 20{,}480$ 维),MLA 的 KV Cache 维度压缩了 ~34 倍

MSA(MiniMax M3):额外存储 Index K Cache:

$$M_{\text{kv}}^{\text{MSA}} = M_{\text{kv}}^{\text{main}} + L_{\text{MSA}} \times H_{\text{idx\_k}} \times D_{\text{idx}} \times T \times \text{bytes\_per\_elem}$$

其中 $M_{\text{kv}}^{\text{main}}$ 与标准 GQA 的公式完全相同(MSA 不减少 KV Cache 存储——稀疏性体现在计算而非存储),$H_{\text{idx_k}}=1$(Index K 只有 1 个头),$D_{\text{idx}}=128$,$L_{\text{MSA}}=57$。

Mamba-2(Nemotron):没有传统 KV Cache。但每层维护一个 SSM 隐状态(conv state + ssm state),维度为 $\approx H_{mamba} \times D_{mamba} \times N = 256 \times 64 \times 128 = 2.1\text{M}$ 元素,48 层合计约 0.2 GiB——可忽略。

案例 1:MiniMax M3,BF16,T=1M

Main KV Cache(60 层,GQA 16:1):

$$\begin{aligned} M_{\text{kv}}^{\text{main}} &= 60 \times 2 \times 4 \times 128 \times 1{,}048{,}576 \times 2 \\ &= 60 \times 2 \times 4 \times 128 \times 1{,}048{,}576 \times 2 \\ &= 128{,}849{,}018{,}880 \text{ bytes} \\ &\approx \mathbf{128.8 \text{ GiB}} \end{aligned}$$

Index K Cache(57 层 MSA):

$$\begin{aligned} M_{\text{kv}}^{\text{index}} &= 57 \times 1 \times 128 \times 1{,}048{,}576 \times 2 \\ &= 15{,}300{,}329{,}472 \text{ bytes} \\ &\approx \mathbf{15.3 \text{ GiB}} \end{aligned}$$

M3 KV Cache 总计(per sample, 1M, BF16):$\approx 128.8 + 15.3 = \mathbf{144.1 \text{ GiB}}$

案例 2:Nemotron 3 Ultra,BF16,T=1M

仅 12 层 Attention(GQA 32:1,$H_{kv}=2$,$D_h=128$):

$$\begin{aligned} M_{\text{kv}}^{\text{Nemotron}} &= 12 \times 2 \times 2 \times 128 \times 1{,}048{,}576 \times 2 \\ &= 12{,}884{,}901{,}888 \text{ bytes} \\ &\approx \mathbf{12.9 \text{ GiB}} \end{aligned}$$

48 层 Mamba 的 SSM 状态约 $\approx 0.2 \text{ GiB}$——总计约 13.1 GiB

Nemotron 的 KV Cache 比 M3 小 11 倍,尽管总参数更大(550B vs 428B)。这就是"尽量不用 Attention"架构策略的显存红利。

案例 3:DeepSeek V4 Flash(MLA),T=1M

MLA 下 KV Cache per layer = $(d_{kv} + D_{rope}) \times T \times 2 = 576 \times 1{,}048{,}576 \times 2 \approx 1.21 \text{ GiB}$。60 层:$\approx 72.4 \text{ GiB}$。对比同尺寸 GQA 模型的 ~144 GiB,MLA 直接砍半。

Batch 效应

KV Cache 是 per-sample 的。batch_size=100 就是 100 倍。这是推理并发的主要瓶颈——权重可以跨 batch 共享,但每个请求需要自己独立的 KV Cache:

$$M_{\text{kv}}^{\text{total}} = B \times M_{\text{kv}}^{(1)}$$

5.4 激活值与临时缓冲

估算前向传播中激活值和临时 buffer 的显存。虽然通常不到权重的 5%,但在规划显存预算时必须留出这部分余量,否则 OOM。

激活值显存来自三个方面:

  1. 残差流:每层前向传播时,hidden_states $\in \mathbb{R}^{B \times S \times d}$ 在 layer 间传递。BF16 下 per token per layer = $d \times 2$ bytes = 12 KB(d=6144)。
  2. 注意力中间结果:Q、K、V、attn_weights 等临时张量。在 decode 阶段($S_{\text{new}}=1$),这些非常小(< 1 MB/layer)。
  3. MoE 中间结果:4 个路由专家的 gate_up 输出(4 × $d_{ff} \times 2$ bytes)。

估算经验值

对于 decode 阶段,激活值显存经验公式:

$$M_{\text{act}} \approx 0.05 \times M_{\text{weights}} \quad \text{(上限经验值)}$$

更精确的逐模块估算:

组件per-token per-layer×60 layers (M3)
残差流 (hidden_states)12 KB (d=6144, BF16)0.72 MB
Attention activations (Q/K/V/attn)~500 KB~30 MB
MoE 4-expert activations~48 KB (4 × 3072 × 2B)~2.9 MB
Per-token sum~0.56 MB~33.6 MB

对于 M3,per-token 激活值约 34 MB。加上框架开销(PyTorch allocator、cuBLAS workspace 等)约 2-5 GiB。

总显存经验公式

$$M_{\text{total}} \approx 1.05 \sim 1.10 \times (M_{\text{weights}} + M_{\text{kv}}^{\text{total}})$$

即总显存大约比"权重 + KV Cache"多 5%~10%。这在显存规划中作为安全余量使用。


5.5 MoE 的专家加载策略

对比 MoE 在全量加载和按需加载两种策略下的显存-性能 trade-off。MoE 占模型参数的 90%+,显存策略的选择直接决定了最低 GPU 数量。

策略 A:全量 Expert 加载(标准做法)

所有 $E$ 个专家的权重始终在显存中。无论 router 选哪个专家,计算是即时的。

  • 显存需求:$E \times \text{Params}_{\text{expert}} \times \text{bytes}$
  • 延迟:可预测,低延迟
  • 并行:通过 EP(Expert Parallelism)将专家分布到多卡,每卡只加载分配给它的专家切片

策略 B:按需 Expert 加载(实验性)

只在 router 选中后才将对应专家权重从 CPU/NVMe 加载到 GPU。

  • 显存需求:$\approx \text{Params}{\text{non-MoE}} + \text{Params}{\text{avg loaded experts}}$,远小于全量
  • 延迟:不可预测——首次 access 需等待 PCIe 传输(~50 GiB/s),远慢于 HBM(~3 TB/s)
  • 适用场景:极端显存受限的离线批处理,不适合在线服务

Nemotron 512 experts 的极端案例

Nemotron 单独专家部分的 BF16 权重:

$$\begin{aligned} \text{Params}_{\text{all experts}} &= 48 \text{ layers} \times 512 \text{ experts} \times (2 \times 2048 \times 5120) \text{ params} \\ &\approx 48 \times 512 \times 21\text{M} = 48 \times 10.74\text{B} = 515.5\text{B} \\ M_{\text{experts only}} &= 515.5 \times 10^9 \times 2 \text{ bytes} = 1{,}031 \text{ GiB} \approx \mathbf{1.03 \text{ TB}} \end{aligned}$$

仅专家权重就超过 1 TB——比总参数(550B × 2 = 1,100 GiB)的 94% 都在专家上。这就是为什么 EP 对 MoE 模型不是"可选的优化"而是"部署的前提条件"。

512 个专家每个 ~21M 参数,48 层,BF16 → 约 1 TB。8 张 H200 每张装 1/8 的专家(EP=8),每卡专家部分约 129 GiB,加上非 MoE 参数(约 35 GiB),刚好塞进 141 GiB 的 H200。没有 EP,即使 16 张 H200 也装不下所有专家复本。


5.6 并行策略的影响(概念级)

解释 TP/PP/EP 三种并行策略如何改变每张 GPU 的实际显存负载。部署计算不是"总显存 / 卡数",不同并行策略按不同维度切分显存。

Tensor Parallelism (TP) —— 切分矩阵乘法

TP 将单个矩阵乘法的权重按列(column-wise)或行(row-wise)切分到 $N$ 张卡。

  • 每卡权重 = $\text{总权重} / N$
  • 代价:每层需要两次 all-reduce 通信(前向 + 反向),通信量与 hidden_size 成正比
  • 适用场景:单层矩阵太大,单卡装不下时

案例:M3 的 Q 投影矩阵 $W_Q \in \mathbb{R}^{6144 \times 8192}$,BF16 下 100.7 MB。单卡轻松装下,不需要 TP。但如果是 1T 参数模型 hidden=16384,$W_Q \in \mathbb{R}^{16384 \times 32768}$ 约 1 GiB——单个矩阵就接近极限。

Pipeline Parallelism (PP) —— 按层切分

PP 将不同层放到不同 GPU。GPU 0 管层 0-14,GPU 1 管层 15-29,以此类推。

  • 每卡权重 $\approx \text{总权重} / N$(但不均衡——MoE 层比 Attention 层重一个数量级)
  • 代价:流水线 bubble(GPU 空闲等待前一级完成);通信仅在 stage 边界
  • 适用场景:层数多、单层内存适中的模型

注意:PP 不能解决"单层太大装不下"的问题——如果 MoE 单层有 7.3B 权重(M3),BF16 下约 14.6 GiB,单卡完全装得下。PP 解决的是"60 层加起来装不下"。

Expert Parallelism (EP) —— 按专家切分(MoE 专用)

EP 是最适合 MoE 模型的并行策略。其核心思想:不同 GPU 持有不同的专家子集,token 通过 all-to-all 通信被路由到持有对应专家的 GPU。

  • 每卡装的专家数 = $E / \text{EP_size}$
  • 每卡专家权重 = $\text{总专家权重} / \text{EP_size}$
  • 代价:token dispatch 和 combine 需要 all-to-all 通信(仅 MoE 层,非所有层)

Nemotron on 8×H200:EP=8,每卡装 512/8 = 64 个专家。每卡专家权重 = $64 \times 48 \times 21\text{M} \times 2 \text{ bytes} \approx 129 \text{ GiB}$。加上非 MoE 参数(Mamba + Attention + Embedding 等)约 35 GiB,总计约 164 GiB——但 H200 只有 141 GiB!

这就引出了一个关键计算。需要检查 8×H200 是否真的够:

$$\begin{aligned} \text{Per-card non-expert} &= (N_{\text{total}} - N_{\text{experts}}) / \text{cards} \\ &\approx (550 - 515.5) / 8 = 4.31 \text{ B} \\ M_{\text{non-expert per card}} &= 4.31 \times 10^9 \times 2 = 8.63 \text{ GiB} \end{aligned}$$$$\begin{aligned} \text{Per-card experts} &= (515.5 \times 10^9 \times 2) / 8 = 128.9 \text{ GiB} \end{aligned}$$$$\text{Per-card total} \approx 8.6 + 128.9 = 137.5 \text{ GiB}$$

137.5 GiB < 141 GiB ——勉强能装下。但如果加上 KV Cache(per sample ~13 GiB / 8 ≈ 1.6 GiB per card if distributed)和激活值,余量非常紧张。

这个计算说明了为什么部署计算不能只看"总显存够不够":并行策略决定了每张卡实际装载的权重分布。

简单部署公式

当只考虑权重显存时的最简估算:

$$\text{Cards}_{\text{min}} = \left\lceil \frac{M_{\text{weights}}}{\text{Per-card memory}} \right\rceil$$

Nemotron BF16:$\lceil 1100 / 141 \rceil = 8$ 张 H200。
M3 BF16:$\lceil 856 / 141 \rceil = 7$ 张 H200(但实际需要 8 张,因为还要考虑 KV Cache batch 效应和 EP 要求专家数可被 EP 大小整除)。


5.7 完整案例:Nemotron 550B on 8×H200

综合运用 5.2-5.6 的知识,做一次完整的部署方案推算。这就是面试中"这个模型需要多少卡"类问题的标准回答模板。

已知条件

  • 模型:Nemotron 3 Ultra,550B 总参,BF16 推理
  • 硬件:8 × NVIDIA H200(141 GiB/card,合计 1,128 GiB)
  • 上下文:1M tokens
  • 架构特征:12 层 Attention(GQA 32:1)+ 48 层 Mamba-2 + 48 层 LatentMoE(512E, top-22)

Step 1:权重显存

$$M_{\text{weights}} = 550 \times 10^9 \times 2 = \mathbf{1{,}100 \text{ GiB}}$$

Step 2:KV Cache(per sample)

$$M_{\text{kv}}^{(1)} = 12 \times 2 \times 2 \times 128 \times 1{,}048{,}576 \times 2 = \mathbf{12.9 \text{ GiB}}$$

(Mamba 层 SSM 状态约 0.2 GiB,计入 act/overhead)

Step 3:可用显存

$$M_{\text{available}} = 1{,}128 - 1{,}100 = \mathbf{28 \text{ GiB}} \quad (\text{8 卡合计})$$

这 28 GiB 是留给 KV Cache + 激活值 + 框架开销的全部余量。

Step 4:Max Batch Size

每个样本消耗的 KV Cache + 激活值:

$$M_{\text{per sample}} = M_{\text{kv}}^{(1)} + M_{\text{act}}^{(1)} \approx 12.9 + 2 = \mathbf{14.9 \text{ GiB}}$$$$B_{\text{max}} = \left\lfloor \frac{28}{14.9} \right\rfloor = \left\lfloor 1.88 \right\rfloor = \mathbf{1 \sim 2 \text{ samples}}$$

更现实地说,max_batch_size = 1(留安全余量给框架开销和 NCCL buffer):

  • batch=1:$1{,}100 + 12.9 + 2 \approx 1{,}115 \text{ GiB} < 1{,}128 \text{ GiB}$ ✓
  • batch=2:$1{,}100 + 25.8 + 4 \approx 1{,}130 > 1{,}128 \text{ GiB}$ ✗(接近极限,可能 OOM)

Step 5:若使用 FP8 权重

$$M_{\text{weights}}^{\text{FP8}} = 550 \times 10^9 \times 1 = \mathbf{550 \text{ GiB}}$$$$M_{\text{available}}^{\text{FP8}} = 1{,}128 - 550 = \mathbf{578 \text{ GiB}}$$$$B_{\text{max}}^{\text{FP8}} = \left\lfloor \frac{578}{14.9} \right\rfloor \approx \mathbf{38 \text{ samples}}$$

从 batch=1 到 batch=38——FP8 将 Nemotron 从一个"勉强能跑"的模型变成一个"可以服务"的模型。

汇总表

精度权重 (GB)KV Cache/样本 (GB)可用 (GB, 8卡)Max Batch
BF161,10012.9281
FP855012.957838
FP8 KV + FP8 W5506.557876
INT4 / NVFP427512.985357
INT4 W + FP8 KV2756.5853115

Nemotron 在 BF16 下是"纯权重瓶颈"——KV Cache 几乎不占什么(只要 13 GiB),但 1.1 TiB 的 BF16 权重把 8 卡池子塞满了 97.5%。FP8 一开,权重减半,同一个池子马上可以跑几十个并发请求。这就是量化在部署中的价值:它解决的是权重显存瓶颈,不是 FLOPs 瓶颈。


CH 6 | 实战——MiniMax M3 完整推演

以 MiniMax M3 为目标,从 config.json 出发,完整推演参数分解 → FLOPs 估算 → KV Cache → 推理显存 → 部署方案,覆盖 GQA + MSA + MoE + Vision + MTP 五种架构变体的计算。M3 是目前覆盖计算变体最多的开源模型——一个模型练完基本上所有架构你都会算了。

本章使用的前置知识(如果你是跳读的,这些概念在这里能找到定义):

  • FLOPs = 2×m×n×k 及 MAC 概念 → CH 1.2
  • GQA 中 K/V 投影「变窄」→ CH 2.3.2(⚠️ 注意 H_kv × D_h ≠ d
  • SwiGLU 的 3 矩阵结构(gate/up/down)→ CH 2.4.2
  • 激活参 vs 总参(MoE 中每个 token 只激活 top-k 专家)→ CH 2.9
  • MSA 的 Index Branch 机制 → CH 3.3
  • MLA/标准 KV cache 公式 → CH 4.2CH 4.3

6.1 从 config.json 出发

打开 MiniMax-M3config.json,提取以下核心字段(text_config 为主,vision_config 为辅):

字段含义
hidden_size6144残差流维度 $d$
num_hidden_layers60总层数 $L$
num_attention_heads64Q 头数 $H_q$
num_key_value_heads4KV 头数 $H_{kv}$
head_dim128每头维度 $D_h$
vocab_size200,064词表大小 $V$
rope_theta5,000,000RoPE 基频
partial_rotary_factor0.5rotary_dim = 0.5 × 128 = 64
num_local_experts128路由专家数 $E$
num_experts_per_tok4每 token 激活专家 $k$
n_shared_experts1共享专家
intermediate_size3072MoE 专家中间维 $d_{moe_ff}$
dense_intermediate_size12288Dense FFN 中间维(前 3 层)
shared_intermediate_size3072共享专家中间维
scoring_funcsigmoid路由评分函数
sparse_block_size128MSA block 大小
sparse_topk_blocks16每 query 选择 top-k blocks
sparse_num_index_heads4Index heads 数
sparse_index_dim128Index head_dim
sparse_disable_index_value[0,0,0,1,…1]层 0-2: Full Attn, 层 3-59: MSA
moe_layer_freq[0,0,0,1,…1]层 0-2: Dense FFN, 层 3-59: MoE
vision_config.hidden_size1280ViT 隐藏维度
vision_config.num_hidden_layers32ViT 层数
vision_config.num_attention_heads16ViT 头数
vision_config.patch_size14Patch 大小
vision_config.image_size2016输入图像尺寸
num_mtp_modules7MTP 模块数
max_position_embeddings1,048,576最大上下文 1M

层类型分配

层范围Attention 类型FFN 类型层数
0-2Full Attention (GQA 16:1)Dense FFN (SwiGLU-OAI, $d_{ff}=12288$)3
3-59MSA Sparse AttentionMoE (128E, top-4, sigmoid)57

6.2 参数分解

以下按模块逐一计算,所有数值均从 6.1 节的 config.json 字段推导。

Embedding 层

$$N_{\text{embed}} = V \times d = 200{,}064 \times 6144 = 1{,}229{,}193{,}216 \approx \mathbf{1.229\text{B}}$$

tie_word_embeddings=false → 输入 Embedding + 输出 LM Head 各一份:

$$N_{\text{embed+head}} = 2 \times 1.229\text{B} = \mathbf{2.458\text{B}}$$

Attention 模块(per layer, Full Attn / MSA 共享)

Q 投影:$d \times H_q \times D_h = 6144 \times 64 \times 128 = 50{,}331{,}648 \approx 50.3\text{M}$
K 投影:$d \times H_{kv} \times D_h = 6144 \times 4 \times 128 = 3{,}145{,}728 \approx 3.1\text{M}$
V 投影:$d \times H_{kv} \times D_h = 3{,}145{,}728 \approx 3.1\text{M}$
O 投影:$H_q \times D_h \times d = 64 \times 128 \times 6144 = 50{,}331{,}648 \approx 50.3\text{M}$

Per-layer Q/K/V/O 合计:$\approx \mathbf{107.0\text{M}}$

Indexer(仅 MSA 层 3-59,57 层)

Index Q 投影:$d \times H_{\text{idx}} \times D_{\text{idx}} = 6144 \times 4 \times 128 = 3{,}145{,}728 \approx 3.1\text{M}$
Index K 投影:$d \times 1 \times D_{\text{idx}} = 6144 \times 128 = 786{,}432 \approx 0.79\text{M}$
Index QK Norm:$2 \times (4 \times 128) + 2 \times 128 = 1{,}280$(可忽略)

Per-layer Indexer 合计:$\approx 3.93\text{M}$

Attention 总参

$$\begin{aligned} N_{\text{attn}} &= 3 \times 107.0\text{M} \quad \text{(层 0-2: Full Attn)} \\ &+ 57 \times (107.0\text{M} + 3.93\text{M}) \quad \text{(层 3-59: MSA + Indexer)} \\ &= 321.0\text{M} + 6{,}323.0\text{M} = \mathbf{6.644\text{B}} \end{aligned}$$

Dense FFN(层 0-2,SwiGLU-OAI,$d_{ff}=12288$)

Per layer(non-gated SwiGLU:gate_up 合并为 $6144 \to 2 \times 12288$):

$$N_{\text{gate\_up}} = 6144 \times 2 \times 12288 = 150{,}994{,}944$$


$$N_{\text{down}} = 12288 \times 6144 = 75{,}497{,}472$$

Per-layer 合计:$\approx 226.5\text{M}$。3 层汇总:$\mathbf{0.679\text{B}}$。

MoE 模块(层 3-59,57 层)

每个路由专家(SwiGLU-OAI,$d_{ff}=3072$):

$$N_{\text{expert}} = 6144 \times 2 \times 3072 + 3072 \times 6144 = 37{,}748{,}736 + 18{,}874{,}368 = 56{,}623{,}104 \approx 56.62\text{M}$$

每层 128 个路由专家

$$N_{\text{experts\_per\_layer}} = 128 \times 56.62\text{M} = 7{,}247{,}757{,}312 \approx 7.25\text{B}$$

共享专家(per layer, 1 个):

$$N_{\text{shared}} = 56.62\text{M} \quad (\text{维度与路由专家相同})$$

路由器(per layer):

$$N_{\text{router}} = d \times E = 6144 \times 128 = 786{,}432 \approx 0.79\text{M}$$

每层 MoE 合计:$7.25\text{B} + 0.057\text{B} + 0.0008\text{B} \approx 7.31\text{B}$

57 层 MoE 汇总:$57 \times 7.31\text{B} = \mathbf{416.6\text{B}}$

Vision(ViT + Projector)

ViT 32 层($d_{vit}=1280$, $H_{vit}=16$, $D_{vit}=80$, $d_{ff}^{vit}=5120$):

Per-layer Attention:$4 \times (1280 \times 16 \times 80) = 6.55\text{M}$
Per-layer MLP:$2 \times 1280 \times 5120 = 13.11\text{M}$
32 层合计:$32 \times 19.66\text{M} \approx 0.63\text{B}$
加 patch embedding + Pre-LN + 3D RoPE:$\approx \mathbf{0.65\text{B}}$

Projector(双阶段 MLP):

Stage 1:$1280 \times 6144 + 6144 \times 6144 \approx 45.6\text{M}$
Stage 2(spatial merge):$(4 \times 6144) \times 6144 + 6144 \times 6144 \approx 188.7\text{M}$
合计:$\approx \mathbf{0.23\text{B}}$

Vision 总计:$\mathbf{0.88\text{B}}$

汇总与自洽性验证

组件参数量 (B)占比
Embedding + LM Head2.4580.58%
Attention (Q/K/V/O × 60)6.4201.50%
Indexer (57 层 MSA)0.2240.05%
Dense FFN (3 层)0.6790.16%
MoE 路由专家 (128 × 57)413.2596.7%
MoE 共享专家3.2270.76%
MoE 路由器0.0450.01%
Vision (ViT + Projector)0.8800.21%
Norm 等~0.001~0%
直接求和~427.2100%
官方标称~428B

偏差 < 0.2%,自洽性验证通过。

一个 428B 参数的模型,96.7% 的参数在 MoE 专家里。Attention 只有 6.4B(1.5%)——所以"优化 Attention"(GQA、MSA、MLA)主要是优化计算量和 KV Cache,而不是参数量。参数量的主战场永远是 FFN/MoE。

激活参数

$$\begin{aligned} N_{\text{active}} &= N_{\text{embed}} + N_{\text{attn}} + N_{\text{dense\_ffn}} + N_{\text{shared}} + k \times N_{\text{expert}} \times 57 + N_{\text{router}} + N_{\text{head}} \\ &= 1.23 + 6.64 + 0.68 + 3.23 + (4/128) \times 413.25 + 0.045 + 1.23 \\ &= 1.23 + 6.64 + 0.68 + 3.23 + 12.91 + 0.045 + 1.23 \\ &\approx \mathbf{26.0\text{B}} \end{aligned}$$

加上 Vision 编码器(图像输入时激活 $\approx 0.88\text{B}$):$\approx 26.9\text{B}$。

官方标称 $\sim 23\text{B}$。差异可能来源:(1) Vision 编码器在纯文本推理时不激活;(2) 部分参数共享(如 non-gated SwiGLU 中 gate/up 共享投影可视为半激活)。

$$\text{激活率} = \frac{26}{428} \approx \mathbf{6.1\%}$$

6.3 FLOPs 估算(Decode, T=1M)

计算 M3 在 1M 上下文下 decode 单个 token 的 FLOPs,并对 MSA 和 Full Attention 做定量对比。理解 MSA 到底省了多少计算——不是省了几个百分点,而是省了几个数量级(在 Attention 计算部分)。

decode 阶段($T_{\text{new}} = 1$,$T_{\text{cached}} = 1{,}048{,}576$)为例,BF16 精度,统计 multiply-add 为 2 FLOPs。

6.3.1 Full Attention 层(3 层,层 0-2)

QK 点积(decode 时 Q 只有 1 token,K 有 T cached):

$$\begin{aligned} \text{FLOPs}_{\text{QK}} &= 2 \times H_q \times D_h \times T \\ &= 2 \times 64 \times 128 \times 1{,}048{,}576 \\ &= 16{,}384 \times 1{,}048{,}576 = 1.718 \times 10^{10} \approx \mathbf{17.2 \text{ GFLOPs}} \end{aligned}$$

Attention-V 加权

$$\begin{aligned} \text{FLOPs}_{\text{AttnV}} &= 2 \times H_q \times T \times D_h \\ &= 2 \times 64 \times 1{,}048{,}576 \times 128 = 17.2 \text{ GFLOPs} \end{aligned}$$

Per Full Attn layer decode FLOPs:$17.2 + 17.2 = \mathbf{34.4 \text{ GFLOPs}}$

3 层合计:$3 \times 34.4 = \mathbf{103.1 \text{ GFLOPs}}$

在 1M 上下文中,即使 Q 只有 1 个新 token,QK 点积也要算 1M 次内积(每个 cached K 对新 Q 算一次相似度)。64 个 Q 头 × 128 维 × 1M tokens × 2 = 16.4B 次运算。这就是 Full Attention 在长上下文 decode 中的致命弱点——每生成一个新 token,要跟之前所有 token 做一次全量比较。

6.3.2 MSA 层(57 层,层 3-59)

MSA 分为 Index Branch + Main Attention。

Index Branch

$$\begin{aligned} \text{FLOPs}_{\text{idx QK}} &= 2 \times H_{\text{idx}} \times D_{\text{idx}} \times T \\ &= 2 \times 4 \times 128 \times 1{,}048{,}576 = 1{,}024 \times 1{,}048{,}576 \\ &\approx \mathbf{1.074 \text{ GFLOPs}} \end{aligned}$$

Main Attention(仅在 $K = \text{topk_blocks} \times \text{block_size} = 16 \times 128 = 2{,}048$ 个 token 上做精确 attention):

$$\begin{aligned} \text{FLOPs}_{\text{main QK}} &= 2 \times H_q \times D_h \times K \\ &= 2 \times 64 \times 128 \times 2048 = 33{,}554{,}432 \approx \mathbf{33.6 \text{ MFLOPs}} \\ \text{FLOPs}_{\text{main AttnV}} &= 2 \times H_q \times K \times D_h = 33.6 \text{ MFLOPs} \end{aligned}$$

Per MSA layer decode FLOPs:$1{,}074 + 33.6 + 33.6 \approx \mathbf{1.14 \text{ GFLOPs}}$

57 层合计:$57 \times 1.14 = \mathbf{65.0 \text{ GFLOPs}}$

6.3.3 QK 加速比:Full Attn vs MSA

Attention QK 计算部分的加速比(只比较 QK 点积,不含线性投影):

$$\frac{\text{FLOPs}_{\text{QK}}^{\text{Full}}}{\text{FLOPs}_{\text{QK}}^{\text{MSA}}} = \frac{2 \times 64 \times 128 \times 1{,}048{,}576}{2 \times 64 \times 128 \times 2048} = \frac{1{,}048{,}576}{2{,}048} = \mathbf{512\times}$$

单层总 Attention FLOPs 加速比(含 Index Branch + Main Attention 的所有 attention 操作):

$$\frac{34.4 \text{ G}}{1.14 \text{ G}} \approx \mathbf{30.2\times}$$

为什么 512× 变成了 30×?因为 MSA 的 Index Branch 自身也有 FLOPs(1.07 GFLOPs),而且这 1.07G 在层总 FLOPs 中占比不小(1.07/1.14 ≈ 94%)。Index Branch 仍然需要 O(T) 的 QK 计算——它的目的是筛选 top-k blocks,而非跳过 QK 计算。

MSA 的 512× QK 加速是在 Main Branch 上实现的(2,048 vs 1M tokens),但 Index Branch 自身仍做 O(T) 扫描(只不过用了更少的 head:4 vs 64,所以也便宜了 16×)。总体效果约 30×,这意味着同样 1M 上下文,MSA 的 decode 比 Full Attention 快 30 倍——但仍然比短上下文(如 4K)的 Full Attention 要慢(因为 Index Branch 的 O(T) 扫描无法避免)。

6.3.4 线性投影 FLOPs(60 层共享)

Q、K、V、O 四个线性投影(per layer):

$$\begin{aligned} \text{Q proj} &= 2 \times 1 \times 6144 \times (64 \times 128) = 2 \times 6144 \times 8192 = 100.7 \text{ MFLOPs} \\ \text{K proj} &= 2 \times 1 \times 6144 \times (4 \times 128) = 2 \times 6144 \times 512 = 6.3 \text{ MFLOPs} \\ \text{V proj} &= 2 \times 1 \times 6144 \times 512 = 6.3 \text{ MFLOPs} \\ \text{O proj} &= 2 \times 1 \times (64 \times 128) \times 6144 = 100.7 \text{ MFLOPs} \end{aligned}$$

Per-layer 投影合计:$\approx 213.9 \text{ MFLOPs}$。60 层:$\mathbf{12.8 \text{ GFLOPs}}$。

6.3.5 MoE FFN FLOPs(57 层,per token)

共享专家(intermediate=3072,SwiGLU-OAI):

$$\begin{aligned} \text{gate\_up} &= 2 \times 1 \times 6144 \times (2 \times 3072) = 2 \times 6144 \times 6144 = 75.5 \text{ MFLOPs} \\ \text{down} &= 2 \times 1 \times 3072 \times 6144 = 37.7 \text{ MFLOPs} \\ \text{shared total} &= 75.5 + 37.7 = \mathbf{113.2 \text{ MFLOPs}} \end{aligned}$$

4 个路由专家

$$\text{routed total} = 4 \times 113.2 = \mathbf{452.8 \text{ MFLOPs}}$$

路由器:$2 \times 1 \times 6144 \times 128 = \mathbf{1.6 \text{ MFLOPs}}$

Per MoE layer decode FLOPs:$113.2 + 452.8 + 1.6 = \mathbf{567.6 \text{ MFLOPs}}$

57 层 MoE:$57 \times 0.5676 = \mathbf{32.4 \text{ GFLOPs}}$

6.3.6 Dense FFN FLOPs(3 层,per token)

Per layer(intermediate=12288):

$$\begin{aligned} \text{gate\_up} &= 2 \times 1 \times 6144 \times (2 \times 12288) = 2 \times 6144 \times 24576 = 302.0 \text{ MFLOPs} \\ \text{down} &= 2 \times 1 \times 12288 \times 6144 = 151.0 \text{ MFLOPs} \\ \text{per layer total} &= \mathbf{453.0 \text{ MFLOPs}} \end{aligned}$$

3 层合计:$\mathbf{1.36 \text{ GFLOPs}}$

6.3.7 MSA Indexer 投影 FLOPs(57 层)

Index Q 投影:$2 \times 1 \times 6144 \times (4 \times 128) = 6.3 \text{ MFLOPs}$
Index K 投影:$2 \times 1 \times 6144 \times 128 = 1.6 \text{ MFLOPs}$
Per-layer:$\approx 7.9 \text{ MFLOPs}$。57 层:$\mathbf{0.45 \text{ GFLOPs}}$

6.3.8 全模型 Decode FLOPs 汇总

组件层数Per-layer (GFLOPs)合计 (GFLOPs)
Full Attention (QK + AttnV)334.4103.1
MSA Attention (Idx + Main)571.1465.0
线性投影 (Q/K/V/O)600.21412.8
Dense FFN30.4531.36
MoE FFN (shared + 4 routed)570.56832.4
Indexer 投影570.0080.45
Embedding + LM Head1~0.02
Total per decode token @1M~215 GFLOPs

6.3.9 与"全 Full Attention M3"对比

如果 M3 的 57 个 MSA 层全部替换为 Full Attention(保持所有其它参数不变):

$$\begin{aligned} \text{FLOPs}_{\text{Full-only QK+AttnV}} &= 103.1 + 57 \times 34.4 = 103.1 + 1{,}960.8 = \mathbf{2{,}064 \text{ GFLOPs}} \\ \text{FLOPs}_{\text{MSA (actual)}} &= 103.1 + 65.0 = \mathbf{168.1 \text{ GFLOPs}} \end{aligned}$$$$\text{Attention 计算加速比} = \frac{2{,}064}{168.1} \approx \mathbf{12.3\times}$$

若计算全模型 FLOPs(含投影 + FFN):

$$\text{FLOPs}_{\text{Full-only total}} = 2{,}064 + 12.8 + 1.36 + 32.4 + 0.02 = 2{,}111 \text{ GFLOPs}$$$$\text{Overall speedup} = \frac{2{,}111}{215} \approx \mathbf{9.8\times}$$

Attention 计算加速 12.3×,但因线性投影和 FFN 不变,总体加速约 10×。M3 花了 57 层 Indexer 的代价(+0.224B 参数,占总参 0.05%),换来了约 10 倍的 decode 速度提升。这是 MSA 被称为 “architectural free lunch” 的原因。


6.4 KV Cache(T=1M)

Main KV Cache(60 层)

$$\begin{aligned} M_{\text{kv}}^{\text{main}} &= 60 \times 2 \times 4 \times 128 \times 1{,}048{,}576 \times 2 \\ &= 60 \times 2 \times 4 \times 128 \times 1{,}048{,}576 \times 2 \\ &= 128{,}849{,}018{,}880 \text{ bytes} \\ &\approx \mathbf{128.8 \text{ GiB}} \end{aligned}$$

Index K Cache(57 层 MSA)

$$\begin{aligned} M_{\text{kv}}^{\text{index}} &= 57 \times 1 \times 128 \times 1{,}048{,}576 \times 2 \\ &= 15{,}300{,}329{,}472 \text{ bytes} \\ &\approx \mathbf{15.3 \text{ GiB}} \end{aligned}$$

总 KV Cache(per sample, BF16)

$$M_{\text{kv}}^{(1)} = 128.8 + 15.3 = \mathbf{144.1 \text{ GiB}}$$

分项占比

KV Cache per sample @1M = 144.1 GiB
┌──────────────────────────────────────────────────────────┐
│████████████████████████████████████████████████          │ Main KV (128.8 GiB, 89.4%)
│██████                                                      │ Index K (15.3 GiB, 10.6%)
└──────────────────────────────────────────────────────────┘

MSA 的 KV Cache 与 Full Attention 完全相同——稀疏性只体现在计算(哪些 KV 被访问),不体现在存储(所有 KV 仍需缓存,因为不同 query 可能选择不同 blocks)。Index K Cache 额外增加了约 10.6% 的 KV Cache 开销。这是 MSA 与 sliding window attention 的本质区别——后者可以裁剪 KV Cache,但 MSA 不能(理论上可以 evict 从未被任何 query 选中的 block,但这需要额外的 bookkeeping)。

Batch Scaling

Batch SizeMain KV (GB)Index KV (GB)Total KV (GB)
1128.815.3144.1
2257.730.6288.3
4515.461.2576.6
81,030.7122.41,153.1

Batch=4 时 KV Cache 已超过 500 GiB——仅 KV Cache 就够塞满 4 张 H200。这是长上下文推理的核心瓶颈。


6.5 推理显存

BF16 精度,单样本,1M 上下文

$$\begin{aligned} M_{\text{weights}} &= 428 \times 10^9 \times 2 = \mathbf{856 \text{ GiB}} \\ M_{\text{kv}} &= \mathbf{144.1 \text{ GiB}} \\ M_{\text{act+overhead}} &\approx \mathbf{5 \text{ GiB}} \\ M_{\text{total}}^{(1)} &= 856 + 144.1 + 5 = \mathbf{1{,}005.1 \text{ GiB}} \end{aligned}$$

硬件匹配

以 8 × H200(141 GiB/card,合计 1,128 GiB)为目标平台:

Step 1:权重装得下吗?

$$M_{\text{weights}} = 856 \text{ GiB} < 1{,}128 \text{ GiB} \quad \checkmark$$

Step 2:可用显存

$$M_{\text{available}} = 1{,}128 - 856 = \mathbf{272 \text{ GiB}} \quad (\text{8 卡合计})$$

Step 3:最大并发 batch

$$B_{\text{max}} = \left\lfloor \frac{272}{144.1 + 5} \right\rfloor = \left\lfloor \frac{272}{149.1} \right\rfloor = \lfloor 1.82 \rfloor = \mathbf{1 \text{ sample}}$$

结论:BF16 下 8×H200 可以跑 M3 的 1M 上下文 BF16 推理,但只能支持最多 1 个并发请求。batch=2 理论上可能($272 / 149.1 \approx 1.8$),但接近显存上限,实际部署中不建议。

FP8 权重 + BF16 KV Cache

$$\begin{aligned} M_{\text{weights}} &= 428 \times 10^9 \times 1 = \mathbf{428 \text{ GiB}} \\ M_{\text{available}} &= 1{,}128 - 428 = \mathbf{700 \text{ GiB}} \\ B_{\text{max}} &= \left\lfloor \frac{700}{149.1} \right\rfloor \approx \mathbf{4 \text{ samples}} \end{aligned}$$

FP8 权重 + FP8 KV Cache

$$\begin{aligned} M_{\text{kv}}^{(1)\text{ FP8}} &= 144.1 / 2 = \mathbf{72.1 \text{ GiB}} \\ M_{\text{available}} &= 1{,}128 - 428 = 700 \text{ GiB} \\ B_{\text{max}} &= \left\lfloor \frac{700}{72.1 + 5} \right\rfloor \approx \mathbf{9 \text{ samples}} \end{aligned}$$

汇总表

精度方案权重 (GB)KV/样本 (GB)可用 (GB)Max Batch @1M
BF16 W + BF16 KV856144.12721
FP8 W + BF16 KV428144.17004
FP8 W + FP8 KV42872.17009
INT4 W + FP8 KV21472.191412

对比 Nemotron(5.7 节):Nemotron 从 BF16→FP8 后 batch 从 1→38,M3 只从 1→4。原因:M3 的 KV Cache 占比高(144 GiB/样本),量化权重解放的显存很快被 KV Cache 吃掉。M3 的显存瓶颈是双重的——权重和 KV Cache 都制约并发。


6.6 验算与交叉对比

与 M3 官方博客声明对照

M3 官方博客声称 MSA 在 1M 上下文下实现 ~30× decode 加速。本节 6.3.3 的直接计算给出:

$$\text{Per-layer Attention FLOPs ratio} = \frac{34.4 \text{ G}}{1.14 \text{ G}} \approx 30.2\times$$

全模型(含投影+FFN):$\approx 9.8\times$。

差异解释:官方的 30× 是指 Attention 计算部分(QK + AttnV),不含线性投影(Q/K/V/O)和 FFN。两者都是正确的——只是口径不同:

  • 30×:Attention 算子层面(孤立地看 MSA 替代 Full Attention 的效果)
  • 10×:端到端 decode 速度(含所有矩阵乘法和 FFN)

当别人说"MSA 让 M3 快了 30 倍",他说的是注意力计算。当你说"为什么我实测只快了 10 倍",因为你还算上了 FFN 和线性投影。两者都对,但需要明确口径。

与纯 Full Attention M3 的显存对比

如果 M3 不使用 MSA(即全部 60 层 Full Attention),KV Cache 变化:

$$\begin{aligned} M_{\text{kv}}^{\text{Full-only}} &= 60 \times 2 \times 4 \times 128 \times 1{,}048{,}576 \times 2 = 128.8 \text{ GiB} \\ M_{\text{kv}}^{\text{MSA (actual)}} &= 144.1 \text{ GiB} \end{aligned}$$

KV Cache 不降反增(+15.3 GiB Index K)——MSA 在显存上不是优化,是略微增加了开销。MSA 的价值在计算(FLOPs),不在存储(Memory)。

与 Nemotron 550B 的横向对比

维度Nemotron 550BMiniMax M3比值
总参550B428B1.29×
BF16 权重1,100 GiB856 GiB1.29×
KV Cache (1M, BF16)13 GiB144 GiB0.09×
KV/Weights 比1.2%16.8%14× 差异
Decode FLOPs/T~300G (estimate)~215G~1.4×
显存瓶颈类型纯权重权重 + KV Cache 双瓶颈
FP8 后 Batch (8×H200)3849.5× 差异

核心洞见:Nemotron 用 Mamba-2 置换 Attention 的策略,在 1M 上下文下产生了约 10× 的 KV Cache 优势。这个优势在短上下文(< 32K)下不明显(因为 KV Cache 本来就小),但随着上下文增长到 1M 时成为决定性的架构差异。MSA 解决了 Attention 的计算瓶颈,但没有解决存储瓶颈——在极端长上下文下,Mamba-2/MLA 的 KV Cache 优势会越来越明显。


本章公式速查

计算目标公式说明
权重显存$M_w = N \times \text{bytes}$$N$ 为总参数量
GQA KV Cache (per layer)$2 \times H_{kv} \times D_h \times T \times \text{bytes}$K 和 V 两份
MLA KV Cache (per layer)$(d_{kv} + D_{rope}) \times T \times \text{bytes}$只存压缩向量
MSA Index KV (per layer)$1 \times D_{\text{idx}} \times T \times \text{bytes}$只有 K
总显存$M_w + B \times M_{kv}^{(1)} + M_{act}$Batch 乘 KV
MoE 激活参$\text{Non-MoE} + k \times \text{Params}_{\text{expert}}$$k$ 为 top-k
Full Attn decode QK FLOPs$2 \times H_q \times D_h \times T$$T$ = cached length
MSA Main QK FLOPs$2 \times H_q \times D_h \times K$$K$ = 2048 (16 blocks × 128)
最低卡数$\lceil M_w / \text{per_card} \rceil$仅考虑权重
最大 Batch$\lfloor (M_{pool} - M_w) / (M_{kv}^{(1)} + M_{act}) \rfloor$考虑 KV Cache

本章常见计算错误

#错误正确做法
1用 $H_q$ 代替 $H_{kv}$ 算 KV CacheKV Cache 宽度由 KV 头数决定(GQA 下 $H_{kv} \ll H_q$),与 Q 头数无关
2MSA 的 KV Cache 忘记 Index KMSA 额外存储 1 个 Index K(每层 1 头 × $D_{idx}$),约占总 KV Cache 的 10%
3认为 MSA 减少了 KV CacheMSA 减少的是计算(FLOPs),不是存储(KV Cache)——两者解耦
4EP 只看总显存够不够EP 要求每张卡装得下其分配的专家子集 + 非 MoE 参数副本,不能只看平均数
5Batch 乘 KV Cache 时忘记 batch 效应权重跨 batch 共享,KV Cache 不共享——$B=100$ 就是 100× KV
6混淆 Attention 加速比和端到端加速比30× 是 per-layer Attention 算子加速;10× 是全模型 end-to-end 加速(含不变的线性投影和 FFN)
7激活值完全忽略虽然通常 < 权重的 5%,但在 tight memory budget 下 5% = 50 GiB(8 卡场景),可能就是 OOM 的原因

各模型 BF16 推理显存横评

模型总参权重 (GB)KV/样本 (GB)可用 (GB)Max Batch
Nemotron 3 Ultra550B1,10013281
MiniMax M3428B8561442721
DeepSeek V4 Flash~300B~600~72 (MLA)~528~7
Kimi K2.5~1T~2,000~72 (MLA)< 0 (OOM!)需 16 卡+

注:K2.5 BF16 推理即使 16 张 H200 (2,256 GiB) 也只能负载 ~2,000 GiB 权重 + 少量 KV。实际部署需要 FP8 或 INT4 量化。


系列导航:CH 1-2(预备知识 + 参数分解)→ CH 3(FLOPs 估算)→ CH 4(KV Cache)→ CH 5(推理显存)→ CH 6(M3 实战推演)


附录

附录 A: 常见 config.json 字段速查表

哪些字段影响哪些计算:

config 字段影响的计算示例值
hidden_size所有投影矩阵参数 + QKV/O 的 FLOPs6144 (M3), 7168 (K2.5), 8192 (Nemotron)
num_hidden_layers总层数 → 乘到每层参数/FLOPs/KV cache60 (M3), 61 (K2.5), 108 blocks (Nemotron)
num_attention_headsQ 投影大小 + QK 点积 FLOPs64 (大多数 7B+ 模型)
num_key_value_headsK/V 投影大小 + KV cache 大小4 (M3 GQA), 2 (Nemotron), 64 (K2.5 MHA)
head_dimQK 点积维度 + KV cache 的 D128 (大多数)
intermediate_sizeFFN 参数(up/gate/down gate)12288 (M3 dense), 18432 (K2.5)
moe_intermediate_sizeMoE expert 参数2048 (M3), 5120 (Nemotron)
n_routed_expertsMoE 总专家数 → 总 MoE 参数128 (M3), 256 (GLM-5.1), 384 (K2.5), 512 (Nemotron)
num_experts_per_tok激活参数计算4 (M3), 8 (K2.5), 22 (Nemotron)
kv_lora_rankMLA KV 压缩维度 → KV cache 大小512 (K2.5, DeepSeek V3/V4)
q_lora_rankMLA Q 压缩维度 → Attention 参数1536 (K2.5)
qk_rope_head_dimMLA k_rope 维度 → KV cache 的 rope 分量64 (K2.5)
ssm_state_sizeMamba-2 state 维度 → 替代 KV cache 的状态大小128 (Nemotron)
max_position_embeddings最大上下文 → KV cache 最大 T + FLOPs 最大 T262144 (K2.5), 1048576 (M3/Nemotron)
vocab_sizeEmbedding 参数 + LM head 参数131072 (Nemotron), 200064 (M3)
dense_intermediate_sizeMoE 模型的 dense FFN 层参数12288 (M3)
shared_intermediate_size共享 expert 的 FFN 参数3072 (M3)
sparse_block_sizeMSA 的 block 大小 → FLOPs 计算128 (M3)
sparse_topk_blocksMSA 的 top-k blocks → FLOPs 计算16 (M3)
vision_config.hidden_sizeViT 参数 + FLOPs1280 (M3), 1152 (K2.5)
vision_config.num_hidden_layersViT 层数32 (M3), 27 (K2.5)
patch_size图像 token 数 → Vision encoder FLOPs14 (大多数)
rope_theta位置编码 theta → 上下文扩展策略判断50000 (K2.5), 5000000 (M3), 10000 (Nemotron)

附录 B: 符号与缩写表

符号含义常用值示例
$d$ / $d_{model}$隐藏维度 (hidden_size)6144, 7168, 8192
$H$Q(Query)头数 (num_attention_heads)64
$H_{kv}$K/V 头数 (num_key_value_heads)4 (GQA), 2 (GQA), 64 (MHA)
$D$每个 head 的维度 (head_dim)128
$d_{ff}$FFN 中间维度 (intermediate_size / moe_intermediate_size)2048-18432
$L$总层数60-108
$L_{attn}$使用 Attention 的层数(Mamba hybrid 中仅部分层)12 (Nemotron)
$T$序列长度(当前总 token 数)4K-1M
$T_{new}$新生成 token 数(decode 时为 1)1
$N_E$MoE 专家总数 (n_routed_experts)128-512
$k$每个 token 激活的专家数 (num_experts_per_tok)4-22
$B$Batch size1 (单样本推理)
$d_{kv}$MLA KV 压缩维度 (kv_lora_rank)512
$d_{rope}$MLA RoPE 维度 (qk_rope_head_dim)64
$H_{mamba}$Mamba-2 SSD head 数256
$d_{state}$Mamba-2 状态空间维度128
$d_{latent}$LatentMoE 低秩维度2048 (Nemotron)
$C$Mamba-2 chunk 大小128
$B_{msa}$MSA block 大小128
$K_{msa}$MSA top-k blocks16
$N_{img}$每图像 token 数576
$V$词表大小 (vocab_size)131072, 200064
字节精度BF16=2, FP8=1, FP4=0.5, FP32=4

附录 C: 8 个已拆解模型的计算结果速览

模型总参激活参FLOPs (decode, T=1M)KV Cache (1M)推理显存 (BF16, 1 sample)
Nemotron 3 Ultra550B55B~1.2×10¹⁵12.0 GiB (仅12 Attn层)~1.13 TiB
MiniMax M3428B23B~2.2×10¹¹144 GiB~1,000 GiB
Kimi K2.51T32B~(未在1M下)~21.5 GiB (256K)— (256K context)
DeepSeek V4-Flash~300B37B~131 GiB (1M, MLA)
MiniMax M2.7~275B~17B— (Full Attn O(T²))
GLM-5.1744B32B
Qwen3.5-MoE~35B~3B
MiMo-V2-Flash~140B~7B

("—" = 该模型未在该上下文长度下做详细估算,或报告未公开该维度数据)


关于本文:本文档从 8 个开源 LLM 的深度架构拆解中提炼而成。每个公式、每个数字都在对应模型上验证通过。如果你发现错误或有改进建议,欢迎反馈。