Grain 是一个用于读取和处理数据的库,这些数据用于训练和评估 JAX 模型。
Grain介绍
- 用户可以使用任意的 Python 转换(transformation)来处理数据。
- Grain 被设计为模块化。如果需要,用户可以使用自己的实现轻松覆盖 Grain 组件。
- 同一管道的多次运行将产生相同的输出。
- Grain 的设计使得检查点的大小最小。抢占后,Grain 可以从中断的地方恢复,并产生与从未被抢占相同的输出。
- 我们在设计 Grain 时特别注意确保其性能良好(请参阅文档的“幕后”部分)。我们还针对多种数据模态(例如文本/音频/图像/视频)对其进行了测试。
- Grain 会尽可能减少其依赖项集。例如,它不应该依赖于 TensorFlow。
Grain安装
pip install grain
导入所有库
from functools import partial
from pathlib import Path

import albumentations as A
import cv2
import grain
import grain.python as pygrain
import numpy as np
import pandas as pd
from PIL import Image
创建Grain数据源
参考Grain官网要求,数据源需要实现两个魔法方法:
class RandomAccessDataSource(Protocol, Generic[T]):
    """Structural interface for data sources whose storage allows efficient random access.

    A data source only has to provide the two special methods declared below.
    """

    def __len__(self) -> int:
        """Return the number of records in the dataset."""
        ...

    def __getitem__(self, record_key: SupportsIndex) -> T:
        """Return the record stored under ``record_key``."""
        ...
自用数据分类的数据源
from pathlib import Path
def get_image_extensions() -> tuple[str, ...]:
    """Return the image file extensions this module treats as supported.

    Keeping the list in a single function gives every caller the same
    definition of "image file" and one place to update it.

    Returns:
        A tuple of lower-case extensions, each including the leading dot
        (for example ``".jpg"`` or ``".png"``).
    """
    suffix_names = ("jpg", "jpeg", "png", "bmp", "tiff", "tif", "webp")
    return tuple(f".{suffix}" for suffix in suffix_names)
class ImageFolderDataSource:
"""A data source class that mimics the structure of torchvision.datasets.ImageFolder.
It expects the root directory to contain subdirectories, where each subdirectory
represents a class and contains images belonging to that class.
Example directory structure:
root_dir/
├── class_a/
│ ├── image1.jpg
│ └── image2.png
└── class_b/
├── image3.jpeg
└── image4.bmp
"""
def __init__(self, root_dir: str):
"""Initializes the ImageFolderDataSource.
Args:
root_dir: The path to the root directory containing class subdirectories.
"""
self.root_dir = Path(root_dir)
self.samples: list[tuple[str, int]] = []
self.classes: list[str] = []
self._load_samples()
def _load_samples(self) -> None:
"""Loads image file paths and their corresponding class indices from the root directory.
This method iterates through subdirectories, treats each subdirectory name as a class,
and collects all valid image files within them.
"""
if not self.root_dir.exists():
msg = f"Root directory {self.root_dir} does not exist."
raise FileNotFoundError(msg)
class_to_idx = {}
valid_extensions = get_image_extensions()
class_dirs = [d for d in self.root_dir.iterdir() if d.is_dir()]
class_dirs.sort(key=lambda x: x.name)
for class_dir in class_dirs:
class_name = class_dir.name
class_idx = len(class_to_idx)
class_to_idx[class_name] = class_idx
self.classes.append(class_name)
for ext in valid_extensions:
for img_path in class_dir.glob(f"*{ext}"):
self.samples.append((str(img_path), class_idx))
if not self.samples:
msg = f"No valid images found in directory '{self.root_dir}'"
raise RuntimeError(msg)
# 实现魔法方法__len__
def __len__(self) -> int:
"""Returns the total number of samples (images) in the dataset.
This allows the dataset object to be used with `len()`.
"""
return len(self.samples)
# 实现魔法方法__getitem__
def __getitem__(self, index: int) -> tuple[str, int]:
"""Returns a sample (image path and its class index) at the given index.
This allows the dataset object to be indexed like a list (e.g., `dataset[0]`).
"""
return self.samples[index]
CLIP数据加载DataSource
class CLIPDataSource:
    """Data source of ``(image_path, text)`` pairs read from a CSV file.

    The CSV must provide an ``image_path`` column and a ``txt`` column; row
    ``i`` of the file becomes sample ``i``.
    """

    def __init__(self, csv_file):
        """Reads the CSV and stores the image-path/text pairs.

        Args:
            csv_file: Anything accepted by ``pandas.read_csv`` (a path or a
                file-like object).

        Raises:
            KeyError: If the ``image_path`` or ``txt`` column is missing.
        """
        df = pd.read_csv(csv_file)
        img_paths = df["image_path"].tolist()
        texts = df["txt"].tolist()
        # __len__ and __getitem__ read self.samples, so the pairs have to be
        # stored on the instance here.
        self.samples: list[tuple[str, str]] = list(zip(img_paths, texts))

    def __len__(self) -> int:
        """Returns the number of image/text pairs."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> tuple[str, str]:
        """Returns the ``(image_path, text)`` pair at position ``idx``."""
        return self.samples[idx]
创建IndexSampler(类似PyTorch的Sampler)
# Instantiate the data source, then the sampler that produces the record
# indices to read from it.
train_dataset = ImageFolderDataSource(params.train_data_path)
train_sampler = pygrain.IndexSampler(
    num_records=len(train_dataset),
    num_epochs=1,
    shard_options=pygrain.NoSharding(),
    shuffle=True,
    seed=params.seed,  # fixed seed, so the shuffled order is reproducible
)
创建DataLoader
在grain中需要自己写数据加载流。
OpenCVLoadImageMap
使用OpenCV读取图像
class OpenCVLoadImageMap(grain.transforms.Map):
    """Replace the path of an ``(image_path, label)`` record with the decoded image."""

    def map(self, element: tuple[str, int]) -> tuple[np.ndarray, int]:
        """Read the image with OpenCV and pass the label through unchanged.

        Args:
            element: ``(image_path, label)`` pair.

        Returns:
            ``(image, label)`` where ``image`` is the array returned by
            ``cv2.imread`` with the ``IMREAD_COLOR_RGB`` flag.

        Raises:
            OSError: If OpenCV cannot read or decode the file.
        """
        img_path, label = element
        img = cv2.imread(img_path, cv2.IMREAD_COLOR_RGB)
        # cv2.imread reports failure by returning None instead of raising;
        # fail here, where the offending path is still known.
        if img is None:
            msg = f"Failed to read image: {img_path}"
            raise OSError(msg)
        return img, label
CLIPOpenCVLoadImageMap
加载CLIP的图像和文本数据集
class CLIPOpenCVLoadImageMap(grain.transforms.Map):
    """Replace the path of an ``(image_path, text)`` record with the decoded image.

    Named ``CLIPOpenCVLoadImageMap`` so that it no longer redefines the
    label-based ``OpenCVLoadImageMap`` under the same name.
    """

    def map(self, element: tuple[str, str]) -> tuple[np.ndarray, str]:
        """Read the image with OpenCV and pass the text through unchanged.

        Args:
            element: ``(image_path, text)`` pair.

        Returns:
            ``(image, text)`` where ``image`` is the array returned by
            ``cv2.imread`` with the ``IMREAD_COLOR_RGB`` flag.

        Raises:
            OSError: If OpenCV cannot read or decode the file.
        """
        img_path, text = element
        img = cv2.imread(img_path, cv2.IMREAD_COLOR_RGB)
        # cv2.imread reports failure by returning None instead of raising;
        # fail here, where the offending path is still known.
        if img is None:
            msg = f"Failed to read image: {img_path}"
            raise OSError(msg)
        return img, text
PILoadImageMap
使用PIL读取图像
class PILoadImageMap(grain.transforms.Map):
    """Replace the path of an ``(image_path, label)`` record with an RGB array loaded via PIL."""

    def map(self, element: tuple[str, int]) -> tuple[np.ndarray, int]:
        """Open the image with PIL, convert it to RGB and return it as an array.

        Args:
            element: ``(image_path, label)`` pair.

        Returns:
            ``(image, label)`` where ``image`` is the RGB-converted picture as
            a NumPy array.
        """
        img_path, label = element
        # The context manager closes the underlying file as soon as the pixel
        # data has been converted, instead of leaving it open until garbage
        # collection.
        with Image.open(img_path) as pil_image:
            img = np.asarray(pil_image.convert(mode="RGB"))
        return img, label
AlbumentationsTransform
使用Albumentations进行图像数据增强
class AlbumentationsTransform(grain.transforms.Map):
    """Run an augmentation pipeline on the image of an ``(image, label)`` record."""

    def __init__(self, transforms):
        """Keep the pipeline to apply.

        Args:
            transforms: Callable invoked as ``transforms(image=...)`` that
                returns a mapping with an ``"image"`` entry (for example the
                ``A.Compose`` built by ``create_transforms``).
        """
        self.transforms = transforms

    def map(self, element: tuple[np.ndarray, int]) -> tuple[np.ndarray, int]:
        """Return the augmented image together with the untouched label."""
        pixels, target = element
        augmented = self.transforms(image=pixels)
        return augmented["image"], target
CLIPAlbumentationsTransform
class CLIPAlbumentationsTransform(grain.transforms.Map):
    """Run an augmentation pipeline on the image of an ``(image, text)`` record."""

    def __init__(self, transforms):
        """Keep the pipeline to apply.

        Args:
            transforms: Callable invoked as ``transforms(image=...)`` that
                returns a mapping with an ``"image"`` entry.
        """
        self.transforms = transforms

    def map(self, element: tuple[np.ndarray, str]) -> tuple[np.ndarray, str]:
        """Return the augmented image together with the untouched text."""
        pixels, caption = element
        result = self.transforms(image=pixels)
        return result["image"], caption
TokenizerMap
对数据集中的Text进行分词。
class TokenizerMap(grain.transforms.Map):
    """Tokenize the text of an ``(image, text)`` record."""

    def __init__(self, tokenizer, context_length=77):
        """Bind the context length to the tokenizer once, up front.

        Args:
            tokenizer: Callable accepting the text and a ``context_length``
                keyword argument.
            context_length: Value forwarded to the tokenizer on every call.
        """
        self.tokenizer = partial(tokenizer, context_length=context_length)

    def map(self, element: tuple[np.ndarray, str]) -> tuple[np.ndarray, np.ndarray]:
        """Return the image unchanged, paired with the tokenized text."""
        image, caption = element
        tokenized = self.tokenizer(caption)
        # Only the first entry of the tokenizer output is kept — presumably the
        # tokenizer returns a batch even for a single string; verify against
        # the tokenizer in use.
        return image, tokenized[0]
create_transforms
创建图像增强
def create_transforms(target_size, *, is_training=True) -> A.Compose:
    """Build the Albumentations pipeline used for training or evaluation.

    Every pipeline starts with a resize to a square of side ``target_size``
    and ends with normalization. In training mode, random flips, rotation,
    colour jitter and a random resized crop are inserted between the two.

    Args:
        target_size: Height and width, in pixels, of the output images.
        is_training: Whether to include the random augmentations.

    Returns:
        An ``A.Compose`` holding the transformations in application order.
    """
    resize = A.Resize(height=target_size, width=target_size, p=1.0)
    # Per-channel mean/std; these look like the usual ImageNet statistics.
    normalize = A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
        max_pixel_value=255.0,
    )

    if not is_training:
        return A.Compose([resize, normalize])

    augmentations = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(limit=30, p=0.5),
        A.ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=0.05,
            p=0.5,
        ),
        A.RandomResizedCrop(
            size=(target_size, target_size),
            scale=(0.8, 1.0),
            ratio=(0.75, 1.33),
            p=0.5,
        ),
    ]
    return A.Compose([resize, *augmentations, normalize])
# Assemble the training pipeline: the sampler picks record indices, the data
# source resolves them to (path, label) pairs, and the operations run in the
# order listed below.
train_loader = pygrain.DataLoader(
    data_source=train_dataset,
    sampler=train_sampler,
    operations=[
        # (path, label) -> (image, label)
        OpenCVLoadImageMap(),
        # Resize, augment and normalize the image.
        AlbumentationsTransform(
            create_transforms(target_size=params.target_size, is_training=True),
        ),
        # Group records into batches; the incomplete last batch is dropped.
        pygrain.Batch(batch_size=params.batch_size, drop_remainder=True),
    ],
    worker_count=params.num_workers,
    worker_buffer_size=2,
)