12.6.3 训练
非常棒的问题!👏
你问的这两点:
- 为什么
X, y = X.to(devices[0]), y.to(devices[0])是必须的? net(X)是如何决定数据在 GPU 之间分配的?
我们来一个一个详细解释。
🔍 第一个问题:
为什么 X, y = X.to(devices[0]), y.to(devices[0]) 是必须的?
✅ 简单回答:
因为你的模型被包装成了
nn.DataParallel,它要求输入数据必须已经在目标设备(通常是主设备devices[0])上,才能启动多 GPU 推理/训练。
🧠 深入理解:
1. nn.DataParallel 的工作原理
当你写:
net = nn.DataParallel(net, device_ids=devices)
- PyTorch 会把你的模型复制到多个 GPU 上(每个
device_ids对应一个副本); - 在每次前向传播时:
- 输入数据会被自动切分;
- 每个 GPU 上的模型副本处理一部分输入;
- 最后输出结果会在主设备(默认是
device_ids[0])上合并;
2. 为什么输入数据要放到 devices[0]?
DataParallel的设计逻辑是:输入数据必须位于主设备上(即device_ids[0]);- 它会从主设备获取输入张量,然后切分并发送到其他 GPU;
- 如果你传入的数据不在
device_ids[0]上,就会报错,例如:
RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device cuda:1
✅ 所以总结一下这行代码的作用:
X, y = X.to(devices[0]), y.to(devices[0])
- 把输入数据显式移动到主设备(也就是第一个 GPU);
- 这样
DataParallel才能正常工作; - 否则,PyTorch 不知道该从哪里开始分割数据。
🔍 第二个问题:
net(X) 是如何决定数据在 GPU 之间的分配的?
✅ 简单回答:
nn.DataParallel会自动将输入数据切分,并分发到各个 GPU 上运行模型副本;
输出结果会在主设备(device_ids[0])上合并。
🧠 工作流程详解:
假设你有 4 个 GPU:[cuda:0, cuda:1, cuda:2, cuda:3]
net = nn.DataParallel(net, device_ids=[0, 1, 2, 3])
当你调用:
output = net(X) # X 在 cuda:0 上
PyTorch 会执行以下步骤:
- 检查输入是否在主设备上(这里是
cuda:0)- 如果不是,抛出错误;
- 将输入
X切分成 4 份(按 batch 维度)- 比如原来的 shape 是
[256, 3, 224, 224],会被切成 4 个[64, 3, 224, 224];
- 比如原来的 shape 是
- 将每一份输入发送到对应的 GPU
- 第一份送到
cuda:0,第二份送到cuda:1,依此类推;
- 第一份送到
- 在每个 GPU 上运行模型副本进行推理/训练
- 每个 GPU 上的模型副本是相同的;
- 收集所有 GPU 的输出,并在主设备上拼接成完整输出
📌 示例图解:
Input X (on cuda:0)
↓
+-------------------+
| DataParallel |
| - Split X |
| - Send to GPUs |
| - Run models |
| - Collect outputs |
+-------------------+
↓
Output (on cuda:0)
✅ 总结一句话:
使用
nn.DataParallel时,你需要确保输入数据已经放在主设备上(如devices[0]),因为DataParallel会从这个设备开始切分和分发数据。模型副本分布在多个 GPU 上,计算完成后结果会在主设备上合并。
🚀 小贴士:DataParallel vs DistributedDataParallel
| 特性 | nn.DataParallel |
torch.nn.parallel.DistributedDataParallel |
|---|---|---|
| 是否支持多进程 | ❌ 否 | ✅ 是 |
| 多卡效率 | ⛽ 较低(单进程) | ⚡ 更高(多进程 + 多线程) |
| 是否需要手动划分数据 | ✅ 是(通常使用 DataLoader) | ✅ 是(使用 DistributedSampler) |
| 推荐场景 | 快速原型开发 | 生产环境、大规模训练 |
如果你对更现代的分布式训练方式感兴趣(如 DistributedDataParallel、FSDP、accelerate 等),我也可以继续讲解 😎

浙公网安备 33010602011771号