训练trick
1、https://www.zhihu.com/follow
2、训练时目标网络 从源网络集成 网络层参数,单又涉及torch.float16到torch.float32转换
data_type = self.model.module.visual.attnpool.c_proj.weight.dtype # 目标网络类型
p_weight = self.p_model.visual.attnpool.c_proj.weight.data # 源参数
p_weight = p_weight.to(data_type) # 参数类型转换
self.model.module.visual.attnpool.c_proj.weight = torch.nn.Parameter(p_weight, requires_grad=False) # 参数赋值给目标参数

浙公网安备 33010602011771号