from torch.multiprocessing.reductions import reduce_tensor || CUDA error: invalid device context
这是一个非常极客的需求。要实现“调用方式不变”且“单进程绕过”,我们需要对 reduce_tensor 的返回结果进行一层高阶函数封装。
1. 多进程重建示例
首先,我们先看在真正的子进程中如何处理。在多进程环境下,reduce_tensor 产生的句柄是可以被正常打开的。
import torch
import torch.multiprocessing as mp
from torch.multiprocessing.reductions import reduce_tensor
def worker(reduce_info):
# 子进程接收到元组
rebuild_fn, args = reduce_info
# 执行重建
reconstructed = rebuild_fn(*args)
print(f"子进程重建成功: {reconstructed.shape}, Device: {reconstructed.device}")
print(f"数值校验: {reconstructed.sum().item():.4f}")
if __name__ == "__main__":
mp.set_start_method('spawn', force=True) # CUDA IPC 必须使用 spawn
tensor = torch.rand(size=(1, 10, 10), dtype=torch.float32, device="cuda:0")
print(f"原进程 Tensor Sum: {tensor.sum().item():.4f}")
# 序列化
reduce_info = reduce_tensor(tensor)
# 启动子进程
p = mp.Process(target=worker, args=(reduce_info,))
p.start()
p.join()
2. 结论:为什么单进程会崩?
- 句柄排他性:CUDA IPC 句柄设计用于进程间所有权转移。当进程尝试
Open一个它已经Handle的内存时,驱动程序会认为这是一个无效的上下文逻辑错误。 - 状态依赖:
rebuild_cuda_tensor内部会调用底层 C++ 的_new_shared_cuda,它强制触发 IPC 映射流程。
3. 终极方案:adaptive_reduce_tensor
为了满足你“不改调用逻辑”且“单进程绕过”的要求,我们可以包装一个包装器。它会记录创建时的 PID,在执行重建时判断:如果是同一个进程,直接返回原引用;如果是新进程,走 IPC。
import os
import torch
from torch.multiprocessing.reductions import reduce_tensor
def _adaptive_rebuild_wrapper(original_tensor, origin_pid, rebuild_fn, args):
"""
自适应重建函数
"""
# 如果当前进程就是产生该 tensor 的进程,直接返回原对象,绕过 CUDA IPC
if os.getpid() == origin_pid:
return original_tensor
# 如果是子进程,则执行原定的重建逻辑
return rebuild_fn(*args)
def adaptive_reduce_tensor(tensor):
"""
包装后的 reduce 方法
"""
# 1. 获取原始的 reduce 信息 (rebuild_fn, args)
rebuild_fn, args = reduce_tensor(tensor)
# 2. 构造新的 reduce_info
# 结构保持为: (函数, 参数元组)
new_fn = _adaptive_rebuild_wrapper
new_args = (tensor, os.getpid(), rebuild_fn, args)
return new_fn, new_args
# --- 测试代码 ---
tensor = torch.rand(size=(1, 10, 10), dtype=torch.float32, device="cuda:0")
# 使用新包装
reduce_info = adaptive_reduce_tensor(tensor)
# 无论在单进程还是多进程,调用方式完全统一:
reconstructed = reduce_info[0](*reduce_info[1])
print(f"重建成功!ID是否一致: {id(tensor) == id(reconstructed)}")
# 单进程下输出 True,且不会报 AcceleratorError
为什么这个方法完美?
- 接口兼容:返回的依然是
(Callable, Tuple),完美适配res[0](*res[1])。 - 性能无损:单进程下只是多了一层
if判断,直接返回引用,连内存拷贝都省了。 - 透明性:当这个
reduce_info被pickle序列化到子进程时,original_tensor会被序列化过去(或者通过 IPC 句柄重建),os.getpid()判定失效,自动走标准的rebuild_fn逻辑。
需要我帮你把这个逻辑集成到一个更完整的分布式调试工具类中吗?
本博文本意在于记录个人的思考与经验,部分博文采用英语写作,可能影响可读性,请见谅
本文来自博客园,作者:ZXYFrank,转载请注明原文链接:https://www.cnblogs.com/zxyfrank/p/19547242

浙公网安备 33010602011771号