MoE(Mixture of Experts)架构的流行自不必多说,近来火出圈的DeepSeek-V3便是MoE架构,传言GPT-5也是MoE架构,国内最近出的一些模型(Qwen3系列相关)也有不少用上了MoE。然而,虽然MoE的研究由来已久,但其应用长时间内都不愠不火,大致上是从去年初的《Mixtral of Experts》开始,MoE才逐渐吸引大家的注意力,其显著优点是参数量大,但训练和推理成本都显著低。
但同时MoE也有一些难题,如训练不稳定、负载不均衡、效果不够好等,这也是它早年没有流行起来的主要原因。不过随着这两年关注度的提升,这些问题在很大程度上已经得到解决,我们在接下来的介绍中会逐一谈到这些内容。
问题定义 我们知道,Transformer模型由Attention层和MLP层组成,MoE替换的是模型中MLP层。MLP层又分FFN(FeedForward Network)和GLU(Gated Linear Unit)两种,主流的是GLU,但简单起见我们还是以FFN为例:
$$y=f(xW^{(A)})W^{(B)}$$其中$x\in\mathbb{R}^d$ 是输入向量(行向量),$W^{(A)}\in\mathbb{R}^{d\times{D}}$, $W^{(B)}\in\mathbb{R}^{D\times{d}}$ 是两个参数矩阵,$f$是Element-wise的激活函数,设$n$是一个能整除$D$的整数,那么上面的FFN可以用分块矩阵等价:
$$ \begin{equation}\boldsymbol{y} = f\big(\boldsymbol{x}\begin{bmatrix}\boldsymbol{W}^{(A)}_1 & \boldsymbol{W}^{(A)}_2 & \cdots & \boldsymbol{W}^{(A)}_n\end{bmatrix}\big)\begin{bmatrix}\boldsymbol{W}^{(B)}_1 \\ \boldsymbol{W}^{(B)}_2 \\ \vdots \\ \boldsymbol{W}^{(B)}_n\end{bmatrix} = \sum_{i=1}^n \underbrace{f(\boldsymbol{x}\boldsymbol{W}^{(A)}_i)\boldsymbol{W}^{(B)}_i}_{\boldsymbol{v}_i}\end{equation} $$ 其中
$W^{(A)}_i = W^{(A)}_{[:,(i-1)c:ic]}$, $W^{(B)}_i = W^{(B)}_{[(i-1)c:ic,:]}$, $c= D/n$,这里的切片按照Python规则来。由此可见,FFN可以等价表示成n个向量
$\boldsymbol{v}_1,\boldsymbol{v}_2,\cdots,\boldsymbol{v}_n$
之和,每个向量代表了一个小模型$f(\boldsymbol{x}\boldsymbol{W}^{(A)}_i)\boldsymbol{W}^{(B)}_i$的输出,每个小模型计算量相同,这些小模型就是MoE中的“Expert”。
MoE提出的问题是:
能否只挑k个向量的和来逼近n个向量的和呢?这样就可以将计算量降低到k/n了。
模长排序 要解决上述的问题,实质上是要解决低秩近似的问题,数学公式就是:
$$\begin{equation}\mathop{\text{argmin}}_{\lambda_1,\lambda_2,\cdots,\lambda_n\in\{0,1\}}\left\Vert\sum_{i=1}^n \lambda_i \boldsymbol{v}_i - \sum_{i=1}^n\boldsymbol{v}_i\right\Vert^2\quad\text{s.t.}\quad \sum_{i=1}^n \lambda_i = k\end{equation}$$ 记$\gamma_i = 1 - \lambda_i$,那么它又可以写成:
$$\begin{equation}\mathop{\text{argmin}}_{\gamma_1,\gamma_2,\cdots,\gamma_n\in\{0,1\}}\left\Vert\sum_{i=1}^n \gamma_i \boldsymbol{v}_i\right\Vert^2\quad\text{s.t.}\quad \sum_{i=1}^n \gamma_i = n - k\end{equation}$$ 这个问题的精确求解是比较困难的(NP Hard),但有一个简单的近似解:当$v_i$两两正交时,我们有
$$\begin{equation}\left\Vert\sum_{i=1}^n \gamma_i \boldsymbol{v}_i\right\Vert^2 = \sum_{i=1}^n \gamma_i^2 \Vert\boldsymbol{v}_i\Vert^2 = \sum_{i=1}^n \gamma_i \Vert\boldsymbol{v}_i\Vert^2\end{equation}$$ 上式最优解显然就是让模长$\Vert\boldsymbol{v}_i\Vert$最小的$n-k$个$\gamma_i$等于1,这又等价于说挑出模长最大的$k$个向量来逼近$n$个向量之和。当$v_i$不满足两两正交的条件时,我们依然用它来作为一个近似解。它的几何意义也很直观,模长越大的向量,在求和过程中越不容易被抵消,从而作用越突出。
...