
Speeding Up Dataset Access in PyTorch with LMDB

The code below lets users speed up dataset access with an LMDataset object. It reads the data from the wrapped dataset ahead of time and stores it in an LMDB database, so later reads go through LMDB instead of the original dataset. Tests on ImageNet show that it speeds up image reading by a factor of 4.25.

Example usage:

import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageNet

from LMDataset import LMDataset


if __name__ == "__main__":
    print("Begin Caching")
    dataset_train: Dataset[torch.Tensor] = ImageNet(
        "~/dataset/ImageNet-1k", split="train"
    )
    cached_dataset = LMDataset(dataset_train, "~/dataset/ImageNet-1k", "train")

    for X, y in cached_dataset:
        ...  # consume the cached samples here

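LMDataset also accepts a transform callable that is applied to each sample after it is read back from the cache, and the cached dataset can be fed to a regular DataLoader. The snippet below is a minimal sketch of that pattern; the val split, the resize/crop/ToTensor transform, and the batch size are illustrative assumptions rather than part of the original example.

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageNet

from LMDataset import LMDataset


if __name__ == "__main__":
    # Hypothetical transform: ImageNet samples are (PIL image, label) pairs;
    # resize/crop to a fixed size and convert to a tensor so that batches
    # can be collated by the DataLoader.
    val_transform = transforms.Compose(
        [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]
    )

    dataset_val = ImageNet("~/dataset/ImageNet-1k", split="val")
    cached_val = LMDataset(
        dataset_val,
        "~/dataset/ImageNet-1k",
        "val",
        transform=lambda sample: (val_transform(sample[0]), sample[1]),
    )

    # The cached dataset behaves like any map-style Dataset, so it can be
    # wrapped in a DataLoader as usual (the batch size here is illustrative).
    loader = DataLoader(cached_val, batch_size=64)
    for X, y in loader:
        ...  # consume the batches here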
The source code, LMDataset.py, is as follows:

# -*- coding: utf-8 -*-
import lmdb
import pickle
from tqdm import tqdm
import multiprocessing as mp
from os import path, makedirs
from torch.utils.data import Dataset, DataLoader
from typing import TypeVar, Optional, Literal, Sized, Tuple, Callable
from torch.nn import Identity

__all__ = ["LMDataset"]

# Begin Configurations
MAX_SIZE = 1024**4  # 1TB
NUM_WORKERS = max(8, mp.cpu_count())
MAX_READERS = max(128, 2 * mp.cpu_count())
PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL
INDEX_SIZE = 4  # keys are 4-byte little-endian sample indices
BATCH_SIZE = 256
# End Configurations


dtype = TypeVar("dtype", covariant=True)
transform_type = TypeVar("transform_type", covariant=True)


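# Helper dataset: wraps the user's dataset so that each sample is returned as
# a pickled (key, value) byte pair, ready to be written into LMDB.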
class _PickleWrapper(Dataset[bytes]):
    def __init__(self, dataset: Dataset[dtype]) -> None:
        super().__init__()
        assert isinstance(dataset, Sized)
        self.dataset = dataset

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> Tuple[bytes, bytes]:
        value = pickle.dumps(self.dataset[index], protocol=PICKLE_PROTOCOL)
        key = index.to_bytes(INDEX_SIZE, "little")
        return key, value


class LMDataset(Dataset[transform_type]):
    def __init__(
        self,
        dataset: Optional[Dataset[dtype]],
        root: str,
        split: Literal["train", "val", "test"],
        transform: Callable[[dtype], transform_type] = Identity(),
        desc: Optional[str] = "Caching Dataset",
    ) -> None:
        assert dataset is None or isinstance(dataset, Sized)
        self.root = path.join(path.expanduser(root), ".LMDB", f"{split}.mdb")
        self.dataset = dataset
        makedirs(path.dirname(self.root), exist_ok=True)

        # Cache the dataset if it has not been cached yet
        if not path.isfile(self.root):
            if dataset is None:
                raise FileNotFoundError(f"No cached dataset found at {self.root}")
            self._cache_dataset(dataset, desc=desc)

        # Open Read-only LMDB Environment
        self.env = lmdb.Environment(
            self.root,
            map_size=MAX_SIZE,
            subdir=False,
            readonly=True,
            max_readers=MAX_READERS,
            max_dbs=4,
            lock=False,
        )

        # Check if dataset is complete
        self.length = self._len_db()
        if dataset is not None:
            assert len(dataset) == self.length, "Dataset Length Mismatch!"

        self.transform = transform

    def _cache_dataset(
        self, dataset: Dataset[dtype], desc: Optional[str] = None
    ) -> None:
        assert isinstance(dataset, Sized)
        makedirs(path.dirname(self.root), exist_ok=True)

        pickle_dataset = _PickleWrapper(dataset)

        data_loader = DataLoader(
            pickle_dataset,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            prefetch_factor=8,
            persistent_workers=True,
        )

        if desc is not None:
            data_loader = tqdm(data_loader, desc=desc)

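        # Durability-related flags below (sync/metasync off, map_async on)
        # are relaxed to speed up this one-off bulk write of the cache.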
        with lmdb.Environment(
            self.root,
            map_size=MAX_SIZE,
            subdir=False,
            readonly=False,
            max_readers=1,
            readahead=False,
            sync=False,
            map_async=True,
            metasync=False,
            meminit=False,
            lock=True,
        ) as write_env:
            with write_env.begin(write=True, buffers=True) as txn:
                for batch in data_loader:
                    for key, value in zip(*batch):
                        txn.put(key, value)

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, index: int) -> transform_type:
        # Raise IndexError for out-of-range indices so plain iteration
        # over the dataset terminates correctly.
        if not 0 <= index < self.length:
            raise IndexError(f"Index {index} out of range for length {self.length}")
        key = index.to_bytes(INDEX_SIZE, "little")
        with self.env.begin(write=False, buffers=True) as txn:
            value: bytes = txn.get(key)  # type: ignore
        return self.transform(pickle.loads(value))

    def _len_db(self) -> int:
        with self.env.begin(write=False, buffers=True) as txn:
            stat = txn.stat(db=None)
        return stat["entries"]

    def __getattr__(self, name: str):
        # Delegate unknown attribute lookups to the wrapped dataset
        return getattr(self.dataset, name)
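To check the speedup on your own machine, a rough timing comparison such as the sketch below can be used. This is only a minimal sketch, not the benchmark behind the 4.25x figure above; the val split and the number of sampled indices are assumptions.

import time

from torchvision.datasets import ImageNet

from LMDataset import LMDataset


def time_reads(dataset, indices) -> float:
    # Read the given indices once and return the elapsed time in seconds.
    start = time.perf_counter()
    for i in indices:
        dataset[i]
    return time.perf_counter() - start


if __name__ == "__main__":
    raw = ImageNet("~/dataset/ImageNet-1k", split="val")
    cached = LMDataset(raw, "~/dataset/ImageNet-1k", "val")

    # Sample roughly 1000 indices spread across the dataset.
    indices = range(0, len(raw), max(1, len(raw) // 1000))
    t_raw = time_reads(raw, indices)
    t_cached = time_reads(cached, indices)
    print(f"raw: {t_raw:.2f}s, cached: {t_cached:.2f}s, "
          f"speedup: {t_raw / t_cached:.2f}x")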