RankMixer的整体架构包含T个输入标记,这些标记经过L个连续的RankMixer块处理,随后接一个输出池化操作。每个RankMixer块包含两个主要组件:(1) 多头令牌混合层,以及(2) 逐令牌前馈网络层,如图所示。首先,输入向量e i n p u t \mathbf{e}_{\mathrm{input}}einput被切分为T个特征令牌x 1 , x 2 , … , x T \mathbf{x}_1, \mathbf{x}_2, \dots, \mathbf{x}_Tx1,x2,…,xT,每个令牌代表一个连贯的特征向量。RankMixer块通过以下方式对标记表示进行L层迭代优化: S n − 1 = L N ( T o k e n M i x i n g ( X n − 1 ) + X n − 1 ) , X n = LN ( PFFN ( S n − 1 ) + S n − 1 ) , (1) \begin{array}{l} \mathrm {S} _ {n - 1} = \mathrm {L N} (\text {T o k e n M i x i n g} (\mathrm {X} _ {n - 1}) + \mathrm {X} _ {n - 1}), \\ \mathrm {X} _ {n} = \operatorname {L N} \left(\operatorname {P F F N} \left(\mathrm {S} _ {n - 1}\right) + \mathrm {S} _ {n - 1}\right), \tag {1} \\ \end{array}Sn−1=LN(T o k e n M i x i n g(Xn−1)+Xn−1),Xn=LN(PFFN(Sn−1)+Sn−1),(1) 其中L N ( ⋅ ) \mathrm{LN}(\cdot)LN(⋅)表示层归一化函数,TokenMixing(⋅ \cdot⋅) 与 PFFN(⋅ \cdot⋅) 分别为多头令牌混合模块与逐令牌前馈网络模块,X n ∈ R T × D \mathbf{X}_n \in \mathbb{R}^{T \times D}Xn∈RT×D是第n nn个 RankMixer 块的输出,X 0 ∈ R T × D \mathbf{X}_0 \in \mathbb{R}^{T \times D}X0∈RT×D由x 1 , x 2 , … , x T \mathbf{x}_1, \mathbf{x}_2, \ldots, \mathbf{x}_Tx1,x2,…,xT堆叠而成,D DD为模型的隐藏维度。输出表示o o u t p u t \mathbf{o}_{\mathrm{output}}ooutput源自最终层表示X L \mathbf{X}_LXL的平均池化,该表示将用于计算不同任务的预测结果。
为了做高效的特征交叉,将每个Token(令牌)划分为H个头,令牌x t \mathbf{x}_txt的第h hh个头记为x t h x_t^hxth: [ x t ( 1 ) ∥ x t ( 2 ) ∥ … ∥ x t ( H ) ] = SplitHead ( x t ) . (3) \left[ \mathbf {x} _ {t} ^ {(1)} \| \mathbf {x} _ {t} ^ {(2)} \| \dots \| \mathbf {x} _ {t} ^ {(H)} \right] = \operatorname {S p l i t H e a d} \left(\mathbf {x} _ {t}\right). \tag {3}[xt(1)∥xt(2)∥…∥xt(H)]=SplitHead(xt).(3)
这些头部可视为将标记x t \mathbf{x}_txt投影到低维特征子空间,因为推荐任务需要从不同视角进行考量。Token混合用于融合这些子空间向量以实现全局特征交互。形式上,经过多头标记混合后,第h hh个头部对应的第h hh个标记s h \mathbf{s}^hsh构建如下: s h = Concat ( x 1 h , x 2 h , … , x T h ) . (4) \mathbf {s} ^ {h} = \operatorname {C o n c a t} \left(\mathbf {x} _ {1} ^ {h}, \mathbf {x} _ {2} ^ {h}, \dots , \mathbf {x} _ {T} ^ {h}\right). \tag {4}sh=Concat(x1h,x2h,…,xTh).(4)
多头令牌混合模块的输出为S ∈ R H × T D H \mathbf{S} \in \mathbb{R}^{H \times \frac{T D}{H}}S∈RH×HTD,由所有重排后的令牌s 1 , s 2 , … , s H \mathbf{s}_1, \mathbf{s}_2, \dots, \mathbf{s}_Hs1,s2,…,sH堆叠而成。本研究中,论文中设定H = T H = TH=T以保持令牌混合后残差连接所需的令牌数量不变。
最后的形式如下: s 1 , s 2 , … , s T = LN ( 令牌混合 ( x 1 , x 2 , … , x T ) + ( x 1 , x 2 , … , x T ) ) (5) \mathbf {s} _ {1}, \mathbf {s} _ {2}, \dots , \mathbf {s} _ {T} = \operatorname {L N} (\text {令牌混合} (\mathbf {x} _ {1}, \mathbf {x} _ {2}, \dots , \mathbf {x} _ {T}) + (\mathbf {x} _ {1}, \mathbf {x} _ {2}, \dots , \mathbf {x} _ {T})) \tag {5}s1,s2,…,sT=LN(令牌混合(x1,x2,…,xT)+(x1,x2,…,xT))(5)
对于令牌s i ∈ R d h s_i \in \mathbb{R}^{d_h}si∈Rdh及其第j jj个专家e i , j ( ⋅ ) e_{i,j}(\cdot)ei,j(⋅),通过路由器h ( ⋅ ) h(\cdot)h(⋅)计算: G i , j = ReLU ( h ( s i ) ) , v i = ∑ j = 1 N e G i , j e i , j ( s i ) , (10) G_{i,j} = \operatorname{ReLU}\left(h\left(\mathbf{s}_{i}\right)\right), \quad \mathbf{v}_{i} = \sum_{j=1}^{N_{e}} G_{i,j} e_{i,j}\left(\mathbf{s}_{i}\right), \tag{10}Gi,j=ReLU(h(si)),vi=j=1∑NeGi,jei,j(si),(10) 其中N e N_{e}Ne表示每个词元的专家数量,N t N_{t}Nt表示词元总数。ReLU路由机制将为高信息量词元激活更多专家,从而提升参数效率。稀疏性通过L r e g \mathcal{L}_{\mathrm{reg}}Lreg正则项进行调控,其系数λ \lambdaλ使平均激活专家比例维持在预算阈值附近: L = L 任务 + λ L 正则 , L 正则 = ∑ i = 1 N t ∑ j = 1 N e G i , j . (11) \mathcal {L} = \mathcal {L} _ {\text {任务}} + \lambda \mathcal {L} _ {\text {正则}}, \quad \mathcal {L} _ {\text {正则}} = \sum_ {i = 1} ^ {N _ {t}} \sum_ {j = 1} ^ {N _ {e}} G _ {i, j}. \tag {11}L=L任务+λL正则,L正则=i=1∑Ntj=1∑NeGi,j.(11)
密集训练/稀疏推理(DTSI-MoE)部分,采用两个路由器h t r a i n h_{\mathrm{train}}htrain和h i n f e r h_{\mathrm{infer}}hinfer,且正则化损失L r e g \mathcal{L}_{\mathrm{reg}}Lreg仅作用于h i n f e r h_{\mathrm{infer}}hinfer。训练期间h t r a i n h_{\mathrm{train}}htrain与h i n f e r h_{\mathrm{infer}}hinfer同步更新,而推理阶段仅使用h i n f e r h_{\mathrm{infer}}hinfer。该方法使专家模型在降低推理成本的同时避免了训练不足的问题。