大模型- moe技术汇总-95

学术界和工业界确实持续在 Mixture of Experts (MoE) 架构上进行创新,

主流 MoE 架构核心组件回顾-基础 MoE 结构

import torch
import torch.nn as nn
import torch.nn.functional as F


class SparseMoELayer(nn.Module):

    def __init__(self, d_model, num_experts, top_k=2):
        super().__init__()

        self.num_experts = num_experts
        self.top_k = top_k

        self.gate = nn.Linear(d_model, num_experts)

        self.experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_experts)])

    def forward(self, x):

        # 1. 计算每个token到每个专家的权重
        router_logits = self.gate(x)  # [B, S, E]  E为专家的个数
        routing_weights = F.softmax(router_logits, dim=-1)

        # 2. 选择top_k个专家 [B, S, K]  E变成K
        topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1)
        """
        topk_weight:
        tensor([[[0.1600, 0.1509],
                 [0.1587, 0.0914],
                 [0.1082, 0.1051],
                 [0.1114, 0.1078],
                 [0.1190, 0.1107],
                 [0.1359, 0.1147],
                 [0.1092, 0.0960],
                 [0.1351, 0.1074],
                 [0.1493, 0.1097],
                 [0.1048, 0.0910]]], grad_fn=<TopkBackward0>)
        
        
        index_id:
        tensor([[[ 9,  2],  # 第一个token由 9,2这两个expert来计算
                 [14, 12],
                 [ 1, 15],
                 [ 3, 13],
                 [11, 14],
                 [ 2,  0],
                 [13,  0],
                 [13,  5],
                 [ 0,  7],
                 [ 8,  1]]])
        """

        # 3. 初始化输出
        output = torch.zeros_like(x)   # [B, S, D]

        # 4. 处理topk_weight, topk_idx
        for k in range(self.top_k):
            expert_idx = topk_idx[..., k]   # [B, S]
            expert_weight = topk_weight[..., k].unsqueeze(-1)  # [B, S, 1]
            '''
            expert_weight:
            tensor([[[0.1600, ],
                     [0.1587, ],
                     [0.1082, ],
                     [0.1114, ],
                     [0.1190, ],
                     [0.1359, ],
                     [0.1092, ],
                     [0.1351, ],
                     [0.1493, ],
                     [0.1048, ]]], grad_fn=<TopkBackward0>)
            
            
            expert_idx:
            tensor([[[ 9, ],
                     [14, ],
                     [ 1, ],
                     [ 3, ],
                     [11, ],
                     [ 2, ],
                     [13, ],
                     [13, ],
                     [ 0, ],
                     [ 8, ]]])
            
            '''

            # 提取专家处理的token
            mask = F.one_hot(expert_idx, num_classes=self.num_experts).bool()  # [B, S, E]
            '''
            tensor([[[False, False, False, False, False, False, False, False, False, True, False, False, False, False, False, False],
                     [False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, False],
                     [False, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
                     [False, False, False, True, False, False, False, False, False, False, False, False, False, False, False, False],
                     [False, False, False, False, False, False, False, False, False, False, False, True, False, False, False, False],
                     [False, False, True, False, False, False, False, False, False, False, False, False, False, False, False, False],
                     [False, False, False, False, False, False, False, False, False, False, False, False, False, True, False, False],
                     [False, False, False, False, False, False, False, False, False, False, False, False, False, True, False, False],
                     [True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
                     [False, False, False, False, False, False, False, False, True, False, False, False, False, False, False, False]]])
            '''
            for i in range(self.num_experts):
                # tensor([[False, False, False, False, False, False, False, False,  True, False]])
                token_mask = mask[..., i]  # [B, S]
                if token_mask.sum() == 0:
                    continue
                expert_output = self.experts[i](x[token_mask])
                output[token_mask] += expert_output *expert_weight[token_mask]
        return output


if __name__ == '__main__':
    moe = SparseMoELayer(d_model=1024, num_experts=16, top_k=2)
    print(moe(torch.randn(1, 10, 1024)).shape)


特点:每个 token 只激活 top_k 个专家,大幅降低计算量。

负载均衡(Load Balancing Loss)

def load_balancing_loss(router_probs, expert_mask):
    """
    router_probs: [B, S, E] 每个 token 分配给专家的概率
    expert_mask: [B, S, E] 实际是否分配给该专家(0/1)
    """
    # 专家负载:实际被使用的频率
    importance = router_probs.sum(dim=[0, 1])  # [E]
    # 路由频率:被路由到的总次数
    usage = expert_mask.sum(dim=[0, 1])  # [E]
    # 平衡损失
    loss = F.mse_loss(usage, importance) * self.num_experts
    return loss

# 在 forward 中返回 loss
loss = load_balancing_loss(routing_weights, mask)

Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity

专家并行(Expert Parallelism)

将专家分布到多个 GPU 上,解决显存瓶颈。

# 使用 PyTorch + FSDP 或 DeepSpeed 实现专家并行
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# 将每个专家包装为独立 FSDP 模块
self.experts = nn.ModuleList([
    FSDP(nn.Linear(d_model, d_model)) for _ in range(num_experts)
])

层级化 MoE(Hierarchical MoE / H-MoE)

专家本身也是 MoE 结构,形成树状路由。

class HierarchicalMoE(nn.Module):
    def __init__(self, d_model, num_groups, experts_per_group):
        super().__init__()
        self.num_groups = num_groups
        self.group_gate = nn.Linear(d_model, num_groups)
        self.group_experts = nn.ModuleList([
            SparseMoELayer(d_model, experts_per_group, top_k=1)
            for _ in range(num_groups)
        ])

    def forward(self, x):
        group_weights = F.softmax(self.group_gate(x), dim=-1)  # [B, S, G]
        group_idx = torch.topk(group_weights, 1, dim=-1)[1]  # [B, S, 1]

        output = torch.zeros_like(x)
        for g in range(self.num_groups):
            mask = (group_idx == g).squeeze(-1)  # [B, S]
            if mask.sum() > 0:
                output[mask] += self.group_experts[g](x[mask]) * group_weights[..., g].unsqueeze(-1)[mask]
        return output

动态稀疏模式(Dynamic Sparsity)

根据输入内容动态决定 top_k 或专家数量。
mask = probs > threshold
创建了一个布尔掩码 mask,它的形状与 probs 相同。 [B, S, E]
mask 中的每个元素如果对应的专家概率大于 threshold,则为 True;否则为 False。这就像一个过滤器,只保留那些“值得考虑”的专家。

k = mask.sum(dim=-1).max().item()

mask.sum(dim=-1) 沿着最后一个维度(专家维度)对 mask 求和。对于每个 token,这个求和结果就是它选择了多少个专家(即有多少个 True 值)。[B, S]
.max():找到所有 token 中,选择专家数量的最大值。
.item():将结果从张量转换为一个 Python 整数。

为什么取最大值?
这是为了确保 torch.topk 函数能够成功执行。torch.topk 要求一个固定的 k 值,如果 k 太小,它可能无法包含所有满足阈值条件的专家。通过取最大值,我们可以保证 topk 操作能够覆盖所有 token 所需的最大专家数量。

总结:k值是动态变化 根据[B, S,E] 输入序列动态计算需要的专家数量

与静态 top-k 的区别?
静态 top-k: torch.topk(probs, k=2, dim=-1)。每个 token 总是选择 2 个专家,即使其中一个专家的概率非常低。概率非常低的也参与计算了
动态 top-k:
优点:更灵活、更智能。对于那些大多数专家概率都很低的 token,它会选择更少的专家(例如 1 个);对于那些多个专家概率都很高的 token,它可能会选择更多的专家(例如 3 个)。
计算效率:这种方法不一定总是比静态 topk 更高效。如果 threshold 设置得太低,导致 k 变得非常大(甚至接近 num_experts),反而会增加计算量。

残差连接与专家融合(Residual + Expert Fusion)

class ResidualMoE(nn.Module):
    def __init__(self, d_model, num_experts, top_k=2):
        super().__init__()
        self.moe = SparseMoELayer(d_model, num_experts, top_k)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        residual = x
        moe_out = self.moe(x)
        return self.norm(residual + moe_out)

论文汇总

模型/项目 创新点 论文链接
DeepSeekMoE 细粒度专家分割 + 共享专家 + 层级路由 arXiv:2401.06066
FairSeq-MoE 负载均衡 + 梯度裁剪 + 专家并行 arXiv:2203.16257
Switch Transformer 简单高效路由 + 大规模训练 arXiv:2101.03961
H3 (Hungry Hungry Hippos) 层级 MoE + 长序列建模 arXiv:2205.07197
posted @ 2025-09-01 14:41  jack-chen666  阅读(36)  评论(0)    收藏  举报