DeepSeek MOE 代码实现

前置知识:

PyTorch 基础函数操作整理

1. topk 操作

功能: torch.topk 用于返回输入张量中指定维度上的前 k 个最大元素及其对应的索引。

示例代码:

import torch

x = torch.tensor([[3, 1, 4],
                  [1, 5, 9],
                  [2, 6, 5]])

values, indices = torch.topk(x, k=2, dim=1)

print(values)
print(indices)

输出:

values: tensor([[4, 3],
                [9, 5],
                [6, 5]])

indices: tensor([[2, 0],
                 [2, 1],
                 [1, 2]])

2. scatter_ 操作

功能: torch.scatter_ 是一个原地操作函数,用于根据指定索引 index,将 src 中的元素分散到目标张量的指定位置。

示例代码:

import torch

indices = torch.tensor([[0, 2],
                        [1, 2],
                        [1, 2]])

result = torch.zeros([3, 3]).scatter_(1, indices, True)
print(result)

输出:

tensor([[1., 0., 1.],
        [0., 1., 1.],
        [0., 1., 1.]])

3. unsqueeze 操作

功能: 在指定维度上插入一个大小为 1 的新维度,从而改变张量的形状。

示例代码:

import torch

x = torch.tensor([[1, 2], [3, 4]])
print("原始张量形状:", x.shape)

y = torch.unsqueeze(x, dim=0)
print("在第 0 维插入新维度后的张量形状:", y.shape)

z = torch.unsqueeze(x, dim=1)
print("在第 1 维插入新维度后的张量形状:", z.shape)

w = torch.unsqueeze(x, dim=2)
print("在第 2 维插入新维度后的张量形状:", w.shape)

输出:

原始张量形状: torch.Size([2, 2])
在第 0 维插入新维度后的张量形状: torch.Size([1, 2, 2])
在第 1 维插入新维度后的张量形状: torch.Size([2, 1, 2])
在第 2 维插入新维度后的张量形状: torch.Size([2, 2, 1])

4. gather 操作

功能: torch.gather 根据给定的索引 index 从输入张量中收集元素,构建一个新的张量。(gatherscatter_ 互为反操作)

示例代码:

import torch

input_tensor = torch.tensor([[10, 20, 30],
                             [40, 50, 60],
                             [70, 80, 90]])

index_tensor = torch.tensor([[2, 0],
                              [1, 2],
                              [0, 1]])

output = torch.gather(input_tensor, dim=1, index=index_tensor)
print(output)

输出:

tensor([[30, 10],
        [50, 60],
        [70, 80]])

5. bincount 操作

功能: 统计非负整数张量中每个值出现的次数。

示例代码:

import torch

input_tensor = torch.tensor([1, 1, 2, 2, 10])
output = torch.bincount(input_tensor)
print(output)

输出:

tensor([0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 1])

6. where 操作

功能: 根据给定的条件对张量元素进行选择性操作,类似于 Python 中的三元运算符,返回满足条件的元素索引。

示例代码:

import torch

input_tensor = torch.tensor([[1, 2], [3, 4]])
indices = torch.where(input_tensor == 2)
print(indices)

输出:

(tensor([0]), tensor([1]))

原始MOE 代码实现

import torch
from torch import nn

# ExpertNetwork 类:定义每个专家的网络
class ExpertNetwork(nn.Module):
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.hidden_size = hidden_size                # 输入和输出的特征维度
        self.intermediate_size = intermediate_size    # 中间层的大小

        # 定义两个线性层
        self.linear1 = nn.Linear(hidden_size, intermediate_size)  # (batch_size, hidden_size) -> (batch_size, intermediate_size)
        self.linear2 = nn.Linear(intermediate_size, hidden_size)  # (batch_size, intermediate_size) -> (batch_size, hidden_size)

    def forward(self, x):
        x = self.linear1(x)                   # 经过第一个线性层
        x = nn.functional.relu(x)             # ReLU 激活函数
        output = self.linear2(x)              # 经过第二个线性层
        return output                         # 返回输出,尺寸为 (batch_size, hidden_size)

# Router 类:用于选择每个输入数据的专家
class Router(nn.Module):
    def __init__(self, hidden_size, expert_num, top_k):
        super().__init__()
        self.router = nn.Linear(hidden_size, expert_num)  # (batch_size, hidden_size) -> (batch_size, expert_num)
        self.top_k = top_k                    # 每次选择 top_k 个专家
        self.hidden_size = hidden_size        # 输入的特征维度

    def forward(self, x):
        x = x.view(-1, self.hidden_size)           # 展平输入,尺寸变为 (batch_size * seq_len, hidden_size)
        x = self.router(x)                         # 通过 router 得到每个专家的选择权重,尺寸为 (batch_size * seq_len, expert_num)
        x = nn.functional.softmax(x, dim=-1)       # 使用 softmax 转换为概率分布,尺寸为 (batch_size * seq_len, expert_num)
        topk_weight, topk_idx = torch.topk(x, k=self.top_k, dim=-1, sorted=False)  # 选择 top_k 个专家,尺寸为 (batch_size * seq_len, top_k)
        
        # 权重归一化,使得它们的和为 1,尺寸为 (batch_size * seq_len, top_k)
        topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
        
        return topk_weight, topk_idx  # 返回选择的 top_k 权重和专家索引

# MOELayer 类:实现混合专家层
class MOELayer(nn.Module):
    def __init__(self, hidden_size, intermediate_size, expert_num, top_k):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.expert_num = expert_num
        self.top_k = top_k

        # 定义多个专家网络
        self.experts = nn.ModuleList(
            [ExpertNetwork(self.hidden_size, self.intermediate_size) for _ in range(self.expert_num)]
        )

        # 定义路由器
        self.router = Router(self.hidden_size, self.expert_num, self.top_k)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()  # 获取输入的尺寸,(batch_size, seq_len, hidden_size)
        token_num = batch_size * seq_len  # 计算总的 token 数量,(batch_size * seq_len)
        x_flat = x.view(token_num, self.hidden_size)  # 展平输入,尺寸为 (batch_size * seq_len, hidden_size)

        # 通过路由器获取 top_k 权重和索引
        topk_weight, topk_idx = self.router(x)
        
        # 初始化输出为零张量,尺寸为 (batch_size * seq_len, hidden_size)
        output = torch.zeros_like(x_flat)

        # 对于每个 token,选择 top_k 个专家进行计算
        for token_idx in range(token_num):  # 遍历所有 token
            for expert_idx in range(self.top_k):  # 遍历每个 token 的 top_k 个专家
                # 选择相应的专家,并计算其输出
                expert = self.experts[topk_idx[token_idx][expert_idx]]
                output[token_idx] += topk_weight[token_idx][expert_idx] * expert(x_flat[token_idx])  # 加权输出

        # 将输出恢复为原始形状 (batch_size, seq_len, hidden_size)
        output = output.view(batch_size, seq_len, self.hidden_size)
        return output

# 设置超参数
HIDDEN_SIZE = 4096
INTERMEDIATE_SIZE = 2048
EXPERT_NUM = 8
TOP_K = 2

# 输入张量,尺寸为 (batch_size, seq_len, hidden_size)
inputs = torch.randn((2, 11, 4096))

# 实例化 MOELayer
moe_layer = MOELayer(HIDDEN_SIZE, INTERMEDIATE_SIZE, EXPERT_NUM, TOP_K)

# 计算输出
outputs = moe_layer(inputs)

# 输出结果的尺寸
print(outputs.size())  # 输出尺寸: (batch_size, seq_len, hidden_size)

DeepSeek MoE

源代码请参考:https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
此处只保留有用的部分

1. Transformer 结构

class Transformer(nn.Module):
    def __init__(self, args):
        self.embed = ...
        self.layers = torch.nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(Block(layer_id, args))
        self.norm = RMSNorm(args.dim)
        self.head = ColumnParallelLinear(...)
    
    def forward(self, tokens):
        h = self.embed(tokens)
        for layer in self.layers:
            h = layer(h, ...)
        h = self.norm(h)
        logits = self.head(h)
        return logits
  • 结构
    • self.embed:嵌入层,转换输入 token。
    • self.layers:使用 ModuleList 存储多个 Block 层。
    • self.norm:最终的 RMSNorm 归一化层。
    • self.head:输出层,使用 ColumnParallelLinear 进行并行计算。
  • 前向传播
    • 先通过 embed 进行 token 转换。
    • 依次经过多个 Block 层。
    • 经过 RMSNorm 归一化。
    • 通过 head 进行最终计算,返回 logits

2. Block 结构

class Block(nn.Module):
    def __init__(self, layer_id, args):
        self.attn = MLA(args)
        self.ffn = MoE(args)
        self.attn_norm = RMSNorm(args.dim)
        self.ffn_norm = RMSNorm(args.dim)

    def forward(self, x):
        x = x + self.attn(self.attn_norm(x))
        x = x + self.ffn(self.ffn_norm(x))
        return x
  • 主要包含:
    • MLA(多头注意力层)。
    • MoE(专家混合机制)。
    • 两个 RMSNorm 层。
  • 前向传播
    • 归一化后,输入 attn 并进行残差连接。
    • 归一化后,输入 MoE 并进行残差连接。

3. MoE(专家混合)

class MoE(nn.Module):
    def __init__(self, args):
        self.n_routed_experts = ...  # 路由专家个数
        self.n_activated_experts = ...  # 激活专家个数
        self.gate = Gate(args)  # 路由选择 Router
        self.experts = nn.ModuleList([Expert(...) for i in range(self.n_routed_experts)])
        self.shared_experts = MLP(...)  # 共享专家网络列表

    def forward(self, x):
        weights, indices = self.gate(x)  # 选择专家及权重
        y = torch.zeros_like(x)  # 初始化
        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).to(x.device)
        for i in range(self.n_routed_experts):  # 轮询专家
            if counts[i] == 0:
                continue
            expert = self.experts[i]
            idx, top = torch.where(indices == i)
            y[idx] += expert(x[idx]) * weights[idx, None]  # 加权累加
        z = self.shared_experts(x)
        return (y + z)
  • Gate 选择 top-k 个专家,并给出权重。
  • torch.bincount 统计每个专家的使用次数。
  • 依次对 n_routed_experts 轮询:
    • 找到被选中的 token。
    • 经过 Expert 计算并加权累加。
    • shared_experts 计算并加上结果。

4. Expert(专家网络)

class Expert(nn.Module):
    def __init__(self, dim, inter_dim):
        self.w1 = Linear(dim, inter_dim)
        self.w2 = Linear(inter_dim, dim)
        self.w3 = Linear(dim, inter_dim)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
  • 结构:
    • w1, w3:分别将 dim → inter_dim
    • w2:将 inter_dim → dim
  • 前向计算:
    • SILU 激活后,与 w3(x) 相乘,再通过 w2

5. Gate(路由选择)

class Gate(nn.Module):
    def __init__(self, args):
        self.topk = args.n_activated_experts
        self.n_groups = args.n_expert_groups
        self.topk_groups = args.n_limited_groups
        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
        self.bias = nn.Parameter(torch.empty(args.n_routed_experts))

    def forward(self, x):
        scores = linear(x, self.weight)  # 计算专家得分
        scores = scores.sigmoid()
        scores = scores.view(x.size(0), self.n_groups, -1)
        group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)  # 选择最高的 2 个分数
        indices = group_scores.topk(self.topk_groups, dim=-1)[1]  # 选择 top 4 组
        mask = torch.zeros_like(scores[..., 0]).scatter_(-1, indices, True)
        scores = (scores * mask.unsqueeze(-1)).flatten(1)
        indices = torch.topk(scores, self.topk, dim=-1)[1]
        weights = original_scores.gather(1, indices)
        return weights, indices
  • 步骤
    • linear(x, weight) 计算 x 对所有专家的分数。
    • 经过 sigmoid 归一化分数。
    • 计算 top-k,选出最优的 4 组。
    • scatter_ 生成 mask 进行筛选。
    • topk 选取最终 k 个专家,并获取权重。

6. MLP(共享专家网络)

class MLP(nn.Module):
    def __init__(self, dim, inter_dim):
        self.w1 = ColumnParallelLinear(dim, inter_dim)
        self.w2 = RowParallelLinear(inter_dim, dim)
        self.w3 = ColumnParallelLinear(dim, inter_dim)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
  • ColumnParallelLinear & RowParallelLinear 进行并行计算。
  • SILU 激活后,与 w3(x) 相乘,再通过 w2 计算。
posted @ 2025-03-03 18:57  AAA建材王师傅  阅读(392)  评论(3)    收藏  举报