训练trick

1、https://www.zhihu.com/follow
2、训练时目标网络 从源网络集成 网络层参数,单又涉及torch.float16到torch.float32转换

 data_type = self.model.module.visual.attnpool.c_proj.weight.dtype # 目标网络类型
 p_weight = self.p_model.visual.attnpool.c_proj.weight.data  # 源参数
 p_weight = p_weight.to(data_type)      # 参数类型转换
 self.model.module.visual.attnpool.c_proj.weight = torch.nn.Parameter(p_weight, requires_grad=False)  # 参数赋值给目标参数
posted @ 2022-08-17 12:46  哈哈哈喽喽喽  阅读(30)  评论(0)    收藏  举报