大模型- moe++-96
MoE++: Accelerating Mixture-of-Experts Methods with Zero-Computation Experts
链接 https://arxiv.org/abs/2410.07348
这篇 MoE++ 的核心创新是 “零计算专家”(Zero-Computation Experts, ZCE),旨在 减少 MoE 推理时的实际计算量,
核心思想:零计算专家(Zero-Computation Experts, ZCE)
传统 MoE 的瓶颈:
专家激活数量固定(如 top_k=2),即使某些专家输出是冗余的。
所有专家都要计算,即使输入内容很简单。
ZCE 的三大核心概念
| 概念 | 解释 | 类比 |
|---|---|---|
| Zero-Computation Expert (ZCE) | 一种“虚拟专家”,其输出直接等于输入(无计算) | 类似恒等映射(Identity Mapping) |
| Adaptive Computation Path | 根据输入复杂度,动态决定是否使用 ZCE | 类似条件计算(Conditional Computation) |
| Routing with ZCE | 路由器可以选择将简单 token 分配给 ZCE | 类似“短路”机制 |
Identity Metric单位矩阵 实现的就是恒等映射Identity Mapping
关键创新:ZCE 不参与实际计算,直接返回输入,节省计算资源!
二、MoE++ 架构详解
输入 Token x
│
▼
[ 路由器 Gate ] ──► 选择 top_k 专家(含 ZCE 选项)
│ 例如:[专家 A, ZCE]
▼
[ 专家计算层 ]
├── 专家 A: 计算 F(x) ← 实际计算
└── ZCE: 直接返回 x ← 零计算!
│
▼
加权输出: y = w_A * F(x) + w_ZCE * x
路由器不仅要决定专家权重,还要 预测哪些 token 适合用 ZCE。
import torch
import torch.nn as nn
import torch.nn.functional as F
class MoEPlusPlusLayer(nn.Module):
def __init__(self, d_model, num_real_experts, top_k=2, zce_ratio=0.3):
super().__init__()
self.num_real_experts = num_real_experts
self.top_k = top_k
self.zce_ratio = zce_ratio # 控制 ZCE 被选中的概率
# 真实专家(需要计算)
self.real_experts = nn.ModuleList([
nn.Sequential(
nn.Linear(d_model, d_model * 4),
nn.GELU(),
nn.Linear(d_model * 4, d_model)
) for _ in range(num_real_experts)
])
# 路由器:输出 [num_real_experts + 1](+1 是 ZCE)
self.gate = nn.Linear(d_model, num_real_experts + 1)
# ZCE 的“权重”(实际上不需要计算)
self.zce_id = num_real_experts # ZCE 的索引
def forward(self, x):
# x: [B, S, d_model]
B, S, _ = x.shape
# 1. 路由
router_logits = self.gate(x) # [B, S, num_real_experts + 1]
routing_weights = F.softmax(router_logits, dim=-1)
# 2. 选择 top_k 专家
topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1) # [B, S, K]
"""
topk_weight:
tensor([[[0.1600, 0.1509],
[0.1587, 0.0914],
[0.1082, 0.1051],
[0.1114, 0.1078],
[0.1190, 0.1107],
[0.1359, 0.1147],
[0.1092, 0.0960],
[0.1351, 0.1074],
[0.1493, 0.1097],
[0.1048, 0.0910]]], grad_fn=<TopkBackward0>)
index_id:
tensor([[[ 9, 2], # 第一个token由 9,2这两个expert来计算
[14, 12],
[ 1, 15],
[ 3, 13],
[11, 14],
[ 2, 0],
[13, 0],
[13, 5],
[ 0, 7],
[ 8, 1]]])
"""
# 3. 初始化输出
output = torch.zeros_like(x)
# 4. 遍历每个选中的专家
for k in range(self.top_k):
expert_idx = topk_idx[..., k] # [B, S]
expert_weight = topk_weight[..., k].unsqueeze(-1) # [B, S, 1]
'''
expert_weight:
tensor([[[0.1600, ],
[0.1587, ],
[0.1082, ],
[0.1114, ],
[0.1190, ],
[0.1359, ],
[0.1092, ],
[0.1351, ],
[0.1493, ],
[0.1048, ]]], grad_fn=<TopkBackward0>)
expert_idx:
tensor([[[ 9, ], # 第一个token的第一个专家id
[14, ], # 第二个token的第一个专家id
[ 1, ],
[ 3, ],
[11, ],
[ 2, ],
[13, ],
[13, ],
[ 0, ],
[ 8, ]]])
'''
for i in range(self.num_real_experts + 1): # 遍历所有的专家
mask = (expert_idx == i) # [B, S] 如果id与当前遍历的专家id相等
"""
tensor([[False, False, False, False, False, False, False, True, True, False]])
"""
if mask.sum() == 0: # 如果不存在与当前遍历的专家id相等
continue
if i == self.zce_id:
# ✅ ZCE:直接返回输入(零计算!)
output[mask] += x[mask] * expert_weight[mask]
else:
# 真实专家:计算 F(x)
expert_out = self.real_experts[i](x[mask])
output[mask] += expert_out * expert_weight[mask]
return output
if __name__ == '__main__':
moe = MoEPlusPlusLayer(d_model=1024, num_real_experts=16, top_k=2)
print(moe(torch.randn(1, 10, 1024)).shape)
当 mask 选中 ZCE 时,没有任何矩阵乘法或激活函数计算
ZCE 的选择策略
路由器直接输出 P(ZCE),训练时通过 负载均衡损失 控制 ZCE 使用频率。
# 在训练时添加 ZCE 使用比例损失
zce_mask = (topk_idx == self.zce_id)
zce_usage = zce_mask.float().mean()
loss_zce = (zce_usage - self.zce_ratio) ** 2 # 鼓励使用 zce_ratio 比例的 ZCE
基于复杂度的选择
# 复杂度预测器(可插拔)
self.complexity_predictor = nn.Linear(d_model, 1)
complexity_score = self.complexity_predictor(x) # [B, S, 1]
use_zce = (complexity_score < threshold).float() # 简单 token 用 ZCE
| 指标 | 传统 MoE | MoE++ (ZCE) | 提升 |
|---|---|---|---|
| 计算量 (FLOPs) | 100% | 60-80% | 20-40% ↓ |
| 延迟 (Latency) | 100ms | 65ms | 35% ↓ |
| 显存 (Memory) | 100% | 95% | 5% ↓ |
| 准确率 (Accuracy) | 基准 | 下降 < 1% | 几乎无损 |
为什么 ZCE 有效?
1 语言建模的冗余性
简单 token(如标点、重复词)不需要复杂变换。
ZCE 相当于 “恒等映射”,保留了原始信息。
2 残差连接的保护
即使 ZCE 输出错误,残差连接 y = x + Δ 仍能保留原始信息。
类似 ResNet 中的恒等路径。
3 路由器的自适应能力
训练后,路由器会 自动学会 将简单 token 分配给 ZCE。
总结:MoE++ (ZCE) 的价值
零计算专家(ZCE):首次提出将“无计算”作为 MoE 的一种专家类型。
自适应计算路径:根据输入复杂度动态分配计算资源。
高效训练策略:通过负载均衡损失控制 ZCE 使用比例。
适用场景
边缘设备部署(如手机、IoT)
高并发低延迟服务(如对话系统)
大规模 MoE 预训练(节省训练成本)

浙公网安备 33010602011771号