优化器导入参数将Tensor结构转为cuda类型
optimizer.load_state_dict(checkpoint['optimizer'])
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.cuda()

浙公网安备 33010602011771号