大模型- moe技术汇总-95
学术界和工业界确实持续在 Mixture of Experts (MoE) 架构上进行创新,
主流 MoE 架构核心组件回顾-基础 MoE 结构
import torch
import torch.nn as nn
import torch.nn.functional as F
class SparseMoELayer(nn.Module):
def __init__(self, d_model, num_experts, top_k=2):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
self.gate = nn.Linear(d_model, num_experts)
self.experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_experts)])
def forward(self, x):
# 1. 计算每个token到每个专家的权重
router_logits = self.gate(x) # [B, S, E] E为专家的个数
routing_weights = F.softmax(router_logits, dim=-1)
# 2. 选择top_k个专家 [B, S, K] E变成K
topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1)
"""
topk_weight:
tensor([[[0.1600, 0.1509],
[0.1587, 0.0914],
[0.1082, 0.1051],
[0.1114, 0.1078],
[0.1190, 0.1107],
[0.1359, 0.1147],
[0.1092, 0.0960],
[0.1351, 0.1074],
[0.1493, 0.1097],
[0.1048, 0.0910]]], grad_fn=<TopkBackward0>)
index_id:
tensor([[[ 9, 2], # 第一个token由 9,2这两个expert来计算
[14, 12],
[ 1, 15],
[ 3, 13],
[11, 14],
[ 2, 0],
[13, 0],
[13, 5],
[ 0, 7],
[ 8, 1]]])
"""
# 3. 初始化输出
output = torch.zeros_like(x) # [B, S, D]
# 4. 处理topk_weight, topk_idx
for k in range(self.top_k):
expert_idx = topk_idx[..., k] # [B, S]
expert_weight = topk_weight[..., k].unsqueeze(-1) # [B, S, 1]
'''
expert_weight:
tensor([[[0.1600, ],
[0.1587, ],
[0.1082, ],
[0.1114, ],
[0.1190, ],
[0.1359, ],
[0.1092, ],
[0.1351, ],
[0.1493, ],
[0.1048, ]]], grad_fn=<TopkBackward0>)
expert_idx:
tensor([[[ 9, ],
[14, ],
[ 1, ],
[ 3, ],
[11, ],
[ 2, ],
[13, ],
[13, ],
[ 0, ],
[ 8, ]]])
'''
# 提取专家处理的token
mask = F.one_hot(expert_idx, num_classes=self.num_experts).bool() # [B, S, E]
'''
tensor([[[False, False, False, False, False, False, False, False, False, True, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, False],
[False, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
[False, False, False, True, False, False, False, False, False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False, False, False, False, True, False, False, False, False],
[False, False, True, False, False, False, False, False, False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False, False, False, False, False, False, True, False, False],
[False, False, False, False, False, False, False, False, False, False, False, False, False, True, False, False],
[True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False, True, False, False, False, False, False, False, False]]])
'''
for i in range(self.num_experts):
# tensor([[False, False, False, False, False, False, False, False, True, False]])
token_mask = mask[..., i] # [B, S]
if token_mask.sum() == 0:
continue
expert_output = self.experts[i](x[token_mask])
output[token_mask] += expert_output *expert_weight[token_mask]
return output
if __name__ == '__main__':
moe = SparseMoELayer(d_model=1024, num_experts=16, top_k=2)
print(moe(torch.randn(1, 10, 1024)).shape)
特点:每个 token 只激活 top_k 个专家,大幅降低计算量。
负载均衡(Load Balancing Loss)
def load_balancing_loss(router_probs, expert_mask):
"""
router_probs: [B, S, E] 每个 token 分配给专家的概率
expert_mask: [B, S, E] 实际是否分配给该专家(0/1)
"""
# 专家负载:实际被使用的频率
importance = router_probs.sum(dim=[0, 1]) # [E]
# 路由频率:被路由到的总次数
usage = expert_mask.sum(dim=[0, 1]) # [E]
# 平衡损失
loss = F.mse_loss(usage, importance) * self.num_experts
return loss
# 在 forward 中返回 loss
loss = load_balancing_loss(routing_weights, mask)
Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity
专家并行(Expert Parallelism)
将专家分布到多个 GPU 上,解决显存瓶颈。
# 使用 PyTorch + FSDP 或 DeepSpeed 实现专家并行
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
# 将每个专家包装为独立 FSDP 模块
self.experts = nn.ModuleList([
FSDP(nn.Linear(d_model, d_model)) for _ in range(num_experts)
])
层级化 MoE(Hierarchical MoE / H-MoE)
专家本身也是 MoE 结构,形成树状路由。
class HierarchicalMoE(nn.Module):
def __init__(self, d_model, num_groups, experts_per_group):
super().__init__()
self.num_groups = num_groups
self.group_gate = nn.Linear(d_model, num_groups)
self.group_experts = nn.ModuleList([
SparseMoELayer(d_model, experts_per_group, top_k=1)
for _ in range(num_groups)
])
def forward(self, x):
group_weights = F.softmax(self.group_gate(x), dim=-1) # [B, S, G]
group_idx = torch.topk(group_weights, 1, dim=-1)[1] # [B, S, 1]
output = torch.zeros_like(x)
for g in range(self.num_groups):
mask = (group_idx == g).squeeze(-1) # [B, S]
if mask.sum() > 0:
output[mask] += self.group_experts[g](x[mask]) * group_weights[..., g].unsqueeze(-1)[mask]
return output
动态稀疏模式(Dynamic Sparsity)
根据输入内容动态决定 top_k 或专家数量。
mask = probs > threshold
创建了一个布尔掩码 mask,它的形状与 probs 相同。 [B, S, E]
mask 中的每个元素如果对应的专家概率大于 threshold,则为 True;否则为 False。这就像一个过滤器,只保留那些“值得考虑”的专家。
k = mask.sum(dim=-1).max().item()
mask.sum(dim=-1) 沿着最后一个维度(专家维度)对 mask 求和。对于每个 token,这个求和结果就是它选择了多少个专家(即有多少个 True 值)。[B, S]
.max():找到所有 token 中,选择专家数量的最大值。
.item():将结果从张量转换为一个 Python 整数。
为什么取最大值?
这是为了确保 torch.topk 函数能够成功执行。torch.topk 要求一个固定的 k 值,如果 k 太小,它可能无法包含所有满足阈值条件的专家。通过取最大值,我们可以保证 topk 操作能够覆盖所有 token 所需的最大专家数量。
总结:k值是动态变化 根据[B, S,E] 输入序列动态计算需要的专家数量
与静态 top-k 的区别?
静态 top-k: torch.topk(probs, k=2, dim=-1)。每个 token 总是选择 2 个专家,即使其中一个专家的概率非常低。概率非常低的也参与计算了
动态 top-k:
优点:更灵活、更智能。对于那些大多数专家概率都很低的 token,它会选择更少的专家(例如 1 个);对于那些多个专家概率都很高的 token,它可能会选择更多的专家(例如 3 个)。
计算效率:这种方法不一定总是比静态 topk 更高效。如果 threshold 设置得太低,导致 k 变得非常大(甚至接近 num_experts),反而会增加计算量。
残差连接与专家融合(Residual + Expert Fusion)
class ResidualMoE(nn.Module):
def __init__(self, d_model, num_experts, top_k=2):
super().__init__()
self.moe = SparseMoELayer(d_model, num_experts, top_k)
self.norm = nn.LayerNorm(d_model)
def forward(self, x):
residual = x
moe_out = self.moe(x)
return self.norm(residual + moe_out)
论文汇总
| 模型/项目 | 创新点 | 论文链接 |
|---|---|---|
| DeepSeekMoE | 细粒度专家分割 + 共享专家 + 层级路由 | arXiv:2401.06066 |
| FairSeq-MoE | 负载均衡 + 梯度裁剪 + 专家并行 | arXiv:2203.16257 |
| Switch Transformer | 简单高效路由 + 大规模训练 | arXiv:2101.03961 |
| H3 (Hungry Hungry Hippos) | 层级 MoE + 长序列建模 | arXiv:2205.07197 |

浙公网安备 33010602011771号