优化器导入参数将Tensor结构转为cuda类型

                optimizer.load_state_dict(checkpoint['optimizer'])

                for state in optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.cuda()
posted @ 2021-12-13 14:11  祥瑞哈哈哈  阅读(453)  评论(0)    收藏  举报