pytorch强制转换模型的所有参数都变成统一类型

可以调用模型的父类Module中的type方法,例如model.type(torch.float64),将网络模型model的参数和缓冲区强制转换为torch.float64类型,这样就可以训练torch.float64类型的数据了,还可以指定其他类型。另外还有一些强制转换为某一种类型的方法:float()、double()、half()、bfloat16()

使用:

net=Net()

net.type(torch.float)

即可。

posted @ 2023-12-06 15:50  ZephyrYin  阅读(322)  评论(0)    收藏  举报