pytorch-multi-gpu-training

  • data_parallel方法
  1. 代码链接
  2. 两处改动
    1. 改动一:py文件最开始导入模块处
    os.environ["CUDA_VISIBLE_DEVICES"]="2,3"  # 必须在`import torch`语句之前设置才能生效
    
    1. 改动二:模型实例化处
    model = Net()
    model = model.to(device)
    model = nn.DataParallel(model)  # 就在这里wrap一下,模型就会使用所有的GPU
    

参考资料:[1] https://github.com/jia-zhuang/pytorch-multi-gpu-training

posted @ 2023-07-05 17:46  Elina-Chang  阅读(14)  评论(0)    收藏  举报