pytorch-multi-gpu-training
- data_parallel方法
- 代码链接
- 两处改动
- 改动一:py文件最开始导入模块处
os.environ["CUDA_VISIBLE_DEVICES"]="2,3" # 必须在`import torch`语句之前设置才能生效- 改动二:模型实例化处
model = Net() model = model.to(device) model = nn.DataParallel(model) # 就在这里wrap一下,模型就会使用所有的GPU
参考资料:[1] https://github.com/jia-zhuang/pytorch-multi-gpu-training

浙公网安备 33010602011771号