稀疏专家模型实现,关键流程分析
- 将数据压平
- 通过 nn.Linear(cfg.in_feature, cfg.expert_num) 得到专家权重
- 通过 torch.topk 得到 每个top上的权重,以及 以专家索引为value的 专家 (batch_size*seq_len,top_k) ,
它表示每个token在不同top_k 上的权重,以及每个token在不同topk_上对应的专家 - 通过 F.one_hot “以专家索引为value” 将value变换为索引,值为0,1
- 通过 变换将专家索引提到第一位
- 通过 torch.nonzero 得到 tok_k 索引列表和 token 索引列表。
它表示不同的top_k与 token之间的一一对应关系 - 通过token索引列表从1中拿到需要处理的数据, 并通过专家去执行 (token, output_featue)
- 通过 tok_k 索引列表和 token 索引列表 去 top权重上去获得权重 (此时数据为1维度,需要unsqueeze)
9 通过7和8的数据逐乘计算权重 - 用token 索引列表 和最终计算的权重 添加到目标输出结果中。
最终代码:
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from xtransformer.moe.config import MoeConfig
from xtransformer.moe.expert import BasicExpert
class SparseMoe(nn.Module):
def __init__(self, cfg: MoeConfig):
super(SparseMoe, self).__init__()
self.cfg = cfg
self.expert_list = nn.ModuleList(
[
BasicExpert(
cfg.in_feature,
cfg.out_feature,
cfg.hidden_dim,
) for _ in range(cfg.expert_num)
]
)
self.share_expert = nn.ModuleList(
[
BasicExpert(
cfg.in_feature,
cfg.out_feature,
cfg.hidden_dim,
) for _ in range(cfg.share_num)
]
)
self.gate = nn.Linear(cfg.in_feature, cfg.expert_num)
def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
# x shape (batch_size,seq_len,hidden_dim)
# 由于每个专家的选中无法从x张量中抽离出新的张量,因为形状无法对齐, 先将x进行压平
# flat_x shape(batch_size * seq_len,-1)
flat_x = x.reshape(-1, x.size(-1))
# 计算专家权重
expert_weights = self.gate(flat_x)
expert_weights = F.softmax(expert_weights, dim=-1)
# topk_value, topk_idx shape (batch_size*seq_len,top_k)
# topk_value 的值是专家权重, topk_idx 值是专家索引,索引我value都是经过排序的
topk_value, topk_idx = expert_weights.topk(self.cfg.top_k, -1)
# 重新初始化权重
topk_value = topk_value / topk_value.sum(dim=-1, keepdim=True)
# 由于专家索引在 topk_idx的value上,我们需要将其转换为索引上, 所以最好的办法是topk_idx的value进行one_hot 编码
# 此时得到的 topk_idx shape (batch_size*seq_len,top_k,expert_num), 值为 0 和 1 , 1 代表选中的专家,,expert_num 表示专家索引
topk_idx = F.one_hot(topk_idx, num_classes=self.cfg.expert_num)
# topk_idx shape(expert_num,top_k,batch_size*seq_len)
topk_idx = topk_idx.permute(2, 1, 0)
# 初始化最终结果张量
final_ret = torch.zeros([flat_x.size(0), cfg.out_feature], device=x.device)
# 接下来的目的,是找对应类别的专家,去执行他们需要处理的token,并计算权重
for expert_idx in range(self.cfg.expert_num):
cur_expert = self.expert_list[expert_idx]
# cur shape (top_k,batch_size*seq_len) 其值为0 和 1 ,0表示被选中,, 1 表示未被选中
cur = topk_idx[expert_idx]
# selected_x 为查找非0的所有索引,
# 其返回值的元素个数等于cur_x 的维度,
# 返回的第一个值代表第一个维度上的索引, 第二个值代表第二维度上的索引 ....
# 每一个返回值都是一个元组, 元组的长度都相同
selected_topk_idx, selected_token_idx = torch.nonzero(cur, as_tuple=True)
# cur_x shape(selected_token_idx,hidden_dim)
cur_x = flat_x[selected_token_idx, :]
# cur_weight shape (selected_tokens) 值为权重
cur_weight = topk_value[selected_token_idx, selected_topk_idx]
cur_weight = cur_weight.unsqueeze(dim=-1)
# 专家执行 expert_ret shape(selected_token_idx,out_feature)
expert_ret = cur_expert(cur_x)
# 专家执行结果与权重进行计算
# expert_ret shape(selected_token_idx,out_feature)
expert_ret = expert_ret * cur_weight
final_ret.index_add_(0, selected_token_idx, expert_ret)
# reshape 到标准结果
final_ret = final_ret.reshape(x.size(0), x.size(1), cfg.out_feature)
# 计算共享专家的结果
for share_expert in self.share_expert:
final_ret = final_ret + share_expert(x)
return final_ret, expert_weights
if __name__ == '__main__':
cfg = MoeConfig()
sm = SparseMoe(cfg)
ret, _ = sm(torch.rand([2, 3, 1024]))
print(ret.shape)
知识是我们已知的
也是我们未知的
基于已有的知识之上
我们去发现未知的
由此,知识得到扩充
我们获得的知识越多
未知的知识就会更多
因而,知识扩充永无止境

浙公网安备 33010602011771号