device

Ref:CSDN

device = torch.device('cuda' if (self.worker == 'gpu' and torch.cuda.is_available()) else 'cpu')

if torch.cuda.device_count() > 1:  # 多gpu
    model = torch.nn.DataParallel(model, device_ids=[x for x in range(self.config.gpu_num)])

几个需要添加to.device的地方

  1. model(如:model.to(device))
  2. input(通常需要使用Variable包装,如:input = Variable(input).to(device))
  3. target(通常需要使用Variable包装,如:target = Variable(torch.from_numpy(np.array(target)).long()).to(device)
  4. nn.CrossEntropyLoss()(如:criterion = nn.CrossEntropyLoss().to(device))
posted @ 2021-11-04 15:45  小康要好好学习  阅读(465)  评论(0编辑  收藏  举报