自定义 DataLoader 时应使用 Unix 系统

自定义 Dataset 类

PyTorch 允许自定义 Dataset 类,并由此获得 DataLoader,能方便训练时获得 batch:

from torch.utils.data import DataLoader, Dataset
import h5py
import os


class RadarDataset(Dataset):

    def __init__(self, directory):
        ...

    def __len__(self):
        ...

    def __getitem__(self, idx):
        ...


dataset = RadarDataset(r'/mnt/z/automotive_pre_processed')
dataloader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=8,
    persistent_workers=True,
    prefetch_factor=3,
)

顺便一提,若要在自定义 Dataset 类中使用 h5py 库,要注意文件句柄不能在 __init__ 里获取和存储。因为 __init__ 中出现的成员会被其他 worker 共享,而 h5py 不允许。

现在问题出现了。使用这个 dataloader,总是出现错误:

dataloader worker (pid(s) 9144, 29312, 25764, 26220, 27448, 27116) exited unexpectedly

问题的解决方法

PyTorch 文档 提到过这个问题。

在Unix上,默认 fork() 启动多进程,子 worker 可以直接访问数据集和 Python 参数函数;

在 Windows 或 MacOS 上,默认 spawn() 启动多进程,即启动另一个解释器运行主脚本,然后内部的工作函数通过 pickle 序列化接收数据集、collate_fn和其他参数。

简单来说,用 Linux 吧,就没这些破事了。

posted @ 2024-03-01 21:22  倒地  阅读(4)  评论(0编辑  收藏  举报