Grain is used to read and process data for training and evaluating JAX models.

Introduction to Grain

Official Grain Docs

  • Users can bring arbitrary Python transformations.
  • Grain is designed to be modular. Users can easily override Grain components with their own implementations when needed.
  • Multiple runs of the same pipeline produce identical outputs.
  • Grain is designed so that checkpoints stay small. After a preemption, Grain can resume from where it left off and produce the same output as if it had never been preempted.
  • Grain was designed carefully to be performant (see the Behind the Scenes section of the documentation). It has also been tested against multiple data modalities (e.g., text/audio/image/video).
  • Grain keeps its dependency set as small as possible. For example, it does not depend on TensorFlow.

Installing Grain

pip install grain

Import all libraries

from functools import partial
from pathlib import Path

import albumentations as A
import cv2
import grain
import grain.python as pygrain
import numpy as np
import pandas as pd
from PIL import Image

Creating a Grain data source

Per the Grain documentation, a data source needs to implement two magic methods:

class RandomAccessDataSource(Protocol, Generic[T]):
  """Interface for datasources where storage supports efficient random access."""

  def __len__(self) -> int:
    """Number of records in the dataset."""

  def __getitem__(self, record_key: SupportsIndex) -> T:
    """Retrieves record for the given record_key."""

A data source for my own image classification data

from pathlib import Path


def get_image_extensions() -> tuple[str, ...]:
    """Returns a tuple of common image file extensions.

    This function provides a centralized list of supported image file types,
    making it easy to manage and update.

    Returns:
        A tuple of strings, where each string is an image file extension
        (e.g., ".jpg", ".png").

    """
    return (".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif", ".webp")


class ImageFolderDataSource:
    """A data source class that mimics the structure of torchvision.datasets.ImageFolder.

    It expects the root directory to contain subdirectories, where each subdirectory
    represents a class and contains images belonging to that class.

    Example directory structure:
    root_dir/
    ├── class_a/
    │   ├── image1.jpg
    │   └── image2.png
    └── class_b/
        ├── image3.jpeg
        └── image4.bmp
    """

    def __init__(self, root_dir: str):
        """Initializes the ImageFolderDataSource.

        Args:
            root_dir: The path to the root directory containing class subdirectories.

        """
        self.root_dir = Path(root_dir)
        self.samples: list[tuple[str, int]] = []
        self.classes: list[str] = []
        self._load_samples()

    def _load_samples(self) -> None:
        """Loads image file paths and their corresponding class indices from the root directory.

        This method iterates through subdirectories, treats each subdirectory name as a class,
        and collects all valid image files within them.
        """
        if not self.root_dir.exists():
            msg = f"Root directory {self.root_dir} does not exist."
            raise FileNotFoundError(msg)

        class_to_idx = {}
        valid_extensions = get_image_extensions()

        class_dirs = [d for d in self.root_dir.iterdir() if d.is_dir()]
        class_dirs.sort(key=lambda x: x.name)

        for class_dir in class_dirs:
            class_name = class_dir.name
            class_idx = len(class_to_idx)
            class_to_idx[class_name] = class_idx
            self.classes.append(class_name)

            for ext in valid_extensions:
                for img_path in class_dir.glob(f"*{ext}"):
                    self.samples.append((str(img_path), class_idx))

        if not self.samples:
            msg = f"No valid images found in directory '{self.root_dir}'"
            raise RuntimeError(msg)

    # Implement the magic method __len__
    def __len__(self) -> int:
        """Returns the total number of samples (images) in the dataset.
        This allows the dataset object to be used with `len()`.
        """
        return len(self.samples)

    # Implement the magic method __getitem__
    def __getitem__(self, index: int) -> tuple[str, int]:
        """Returns a sample (image path and its class index) at the given index.
        This allows the dataset object to be indexed like a list (e.g., `dataset[0]`).
        """
        return self.samples[index]
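
A quick usage sketch (the directory path here is hypothetical):

dataset = ImageFolderDataSource("data/train")
print(len(dataset))       # total number of images
print(dataset.classes)    # class names, sorted by folder name
print(dataset[0])         # (image_path, class_index) for the first sample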

DataSource for CLIP data loading

class CLIPDataSource:
    """Loads (image_path, text) pairs from a CSV file with `image_path` and `txt` columns."""

    def __init__(self, csv_file):
        df = pd.read_csv(csv_file)
        img_paths = df["image_path"].tolist()
        texts = df["txt"].tolist()
        # Store the pairs so __len__/__getitem__ can access them.
        self.samples: list[tuple[str, str]] = list(zip(img_paths, texts))

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> tuple[str, str]:
        return self.samples[idx]
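
The CSV is assumed to look roughly like this (hypothetical paths and captions, shown only to illustrate the two expected columns):

image_path,txt
images/0001.jpg,a photo of a cat sitting on a sofa
images/0002.jpg,a dog running on the beach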

Creating an IndexSampler (similar to PyTorch's Sampler)

# Instantiate the data source
train_dataset = ImageFolderDataSource(params.train_data_path)
train_sampler = pygrain.IndexSampler(
    num_records=len(train_dataset),
    shuffle=True,
    seed=params.seed,
    shard_options=pygrain.NoSharding(),
    num_epochs=1,
)
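
For evaluation, a sampler without shuffling works the same way. A minimal sketch, assuming an `eval_dataset` built like `train_dataset` above (Grain also offers sharding options such as `ShardByJaxProcess` for multi-host runs):

eval_sampler = pygrain.IndexSampler(
    num_records=len(eval_dataset),
    shuffle=False,  # keep a deterministic order for evaluation
    seed=params.seed,
    shard_options=pygrain.NoSharding(),
    num_epochs=1,
)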

Creating the DataLoader

In Grain, you write the data loading pipeline yourself as a sequence of transformation operations.

OpenCVLoadImageMap

Load images with OpenCV

class OpenCVLoadImageMap(grain.transforms.Map):
    def map(self, element: tuple[str, int]) -> tuple[np.ndarray, int]:
        img_path, label = element
        img = cv2.imread(img_path, cv2.IMREAD_COLOR_RGB)
        return img, label
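
`cv2.IMREAD_COLOR_RGB` is only available in relatively recent OpenCV releases. If your build does not have it, an equivalent sketch reads in OpenCV's default BGR order and converts to RGB:

class OpenCVLoadImageBGRMap(grain.transforms.Map):
    """Fallback loader for OpenCV builds without cv2.IMREAD_COLOR_RGB."""

    def map(self, element: tuple[str, int]) -> tuple[np.ndarray, int]:
        img_path, label = element
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)  # loaded as BGR
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)    # convert to RGB
        return img, label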

CLIPOpenCVLoadImageMap

Load the image and text pairs for the CLIP dataset

class CLIPOpenCVLoadImageMap(grain.transforms.Map):
    def map(self, element: tuple[str, str]) -> tuple[np.ndarray, str]:
        img_path, text = element
        img = cv2.imread(img_path, cv2.IMREAD_COLOR_RGB)
        return img, text

PILoadImageMap

Load images with PIL

class PILoadImageMap(grain.transforms.Map):
    def map(self, element: tuple[str, int]) -> tuple[np.ndarray, int]:
        img_path, label = element
        img = np.asarray(Image.open(img_path).convert(mode="RGB"))
        return img, label

AlbumentationsTransform

Apply image augmentation with Albumentations

class AlbumentationsTransform(grain.transforms.Map):
    def __init__(self, transforms):
        self.transforms = transforms

    def map(self, element: tuple[np.ndarray, int]) -> tuple[np.ndarray, int]:
        image, label = element
        transformed_image = self.transforms(image=image)["image"]
        return transformed_image, label

CLIPAlbumentationsTransform

Apply the same Albumentations augmentation to image-text pairs; the text is passed through unchanged

class CLIPAlbumentationsTransform(grain.transforms.Map):
    def __init__(self, transforms):
        self.transforms = transforms

    def map(self, element: tuple[np.ndarray, str]) -> tuple[np.ndarray, str]:
        image, text = element
        transformed_image = self.transforms(image=image)["image"]
        return transformed_image, text

TokenizerMap

Tokenize the text in the dataset.

class TokenizerMap(grain.transforms.Map):
    def __init__(self, tokenizer, context_length=77):
        self.tokenizer = partial(tokenizer, context_length=context_length)

    def map(self, element: tuple[np.ndarray, str]) -> tuple[np.ndarray, np.ndarray]:
        img, txt = element
        text = self.tokenizer(txt)[0]
        return img, text
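
TokenizerMap assumes a CLIP-style tokenizer: a callable taking `(texts, context_length=...)` and returning an array of shape `(num_texts, context_length)`, which is why the result is indexed with `[0]`. A hypothetical stand-in below sketches that interface (a real pipeline would plug in an actual CLIP tokenizer):

def toy_tokenizer(texts, context_length=77):
    """Hypothetical tokenizer stand-in: hashes whitespace tokens to ids."""
    if isinstance(texts, str):
        texts = [texts]
    ids = np.zeros((len(texts), context_length), dtype=np.int32)
    for i, text in enumerate(texts):
        tokens = [hash(w) % 30000 + 1 for w in text.split()][:context_length]
        ids[i, : len(tokens)] = tokens
    return ids


tokenize_op = TokenizerMap(toy_tokenizer, context_length=77)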

create_transforms

Create the image augmentation pipeline

def create_transforms(target_size, *, is_training=True) -> A.Compose:
    """Create image augmentation and normalization transformations.

    Args:
        target_size: The desired height and width for the images.
        is_training: A boolean indicating whether to apply training-specific augmentations.

    Returns:
        An `A.Compose` object containing the sequence of transformations.

    """
    transforms_list = [
        A.Resize(height=target_size, width=target_size, p=1.0),
    ]

    if is_training:
        transforms_list.extend(
            [
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.Rotate(limit=30, p=0.5),
                A.ColorJitter(
                    brightness=0.2,
                    contrast=0.2,
                    saturation=0.2,
                    hue=0.05,
                    p=0.5,
                ),
                A.RandomResizedCrop(
                    size=(target_size, target_size),
                    scale=(0.8, 1.0),
                    ratio=(0.75, 1.33),
                    p=0.5,
                ),
            ],
        )

    transforms_list.append(
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
        ),
    )

    return A.Compose(transforms_list)
Finally, assemble the data source, the sampler, and the transformation operations into a DataLoader:

train_loader = pygrain.DataLoader(
    data_source=train_dataset,
    sampler=train_sampler,
    worker_count=params.num_workers,
    worker_buffer_size=2,
    operations=[
        OpenCVLoadImageMap(),
        AlbumentationsTransform(
            create_transforms(
                target_size=params.target_size,
                is_training=True,
            ),
        ),
        pygrain.Batch(
            params.batch_size,
            drop_remainder=True,
        ),
    ],
)
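
Iterating over the loader should yield batches of NumPy arrays (Grain's Batch stacks each component of the tuple), which can be fed directly to a JAX training step. A quick sanity-check sketch:

for images, labels in train_loader:
    # images: (batch_size, target_size, target_size, 3), labels: (batch_size,)
    print(images.shape, labels.shape)
    break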