数据集
这里记录一些我用过的数据集
- 可以从 tensorflow_datasets 数据集列表 找到几乎所有数据集的 source。
- torch.utils.data.Dataset 提供了常见数据集的读取类。
ImageNet
适用任务:
- 图像分类
- 目标定位
- 目标检测
经常使用的是 ImageNet 的一个子集,ILSVRC2012(ImageNet-1K)。
wget
# 下载数据集
wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar
wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar
#wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_test_v10102019.tar
wget https://www.image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz
# 解压数据集
python -c "from torchvision.datasets import ImageNet; ImageNet('.', split='train')"
python -c "from torchvision.datasets import ImageNet; ImageNet('.', split='val')"
datasets
from datasets import load_dataset
ds = load_dataset("ILSVRC/imagenet-1k")
kaggle
kaggle competitions download -c imagenet-object-localization-challenge
AutoDL 提供了 ImageNet100 和 ImageNet 数据集:
/root/autodl-pub/ImageNet100
/root/autodl-pub/ImageNet
在评估 ILSVRC2012 的分类模型时,通常使用验证集(val)来报告模型的分类准确率。测试集(test)一般用于比赛或最终评估,通常不公开其标签,只在官方评测服务器上使用。
参见:
COCO
适用任务:
- 目标检测
- 实例分割
- 图像描述
- 关键点检测
- 全景分割
- 密集图像描述
经常使用的是 COCO 的一个子集,COCO-2017。
wget
# COCO-2017
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/zips/test2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
wget http://images.cocodataset.org/annotations/image_info_test2017.zip
unzip *2017.zip
datasets
ds = load_dataset("detection-datasets/coco", split="train")
fiftyone
import fiftyone.zoo as foz
import fiftyone as fo
# 加载 COCO 2017 数据集
dataset = foz.load_zoo_dataset(
"coco-2017",
split="train",
dataset_dir="/path/to/save"
)
# 启动数据集浏览工具
session = fo.launch_app(dataset)
CIFAR-10
wget https://www.cs.utoronto.ca/~kriz/cifar-10-python.tar.gz
wget https://www.cs.utoronto.ca/~kriz/cifar-100-python.tar.gz
AutoDL 提供了 CIFAR-10 和 CIFAR-100 的数据集:
/root/autodl-pub/cifar-10
/root/autodl-pub/cifar-100
from torchvision.datasets import CIFAR10
# 训练集
train_data = CIFAR10('./data', train=True, transform=None, target_transform=None, download=True)
# 测试集
test_data = CIFAR10('./data', train=False, transform=None, target_transform=None, download=True)
dataset_dir
:存放数据集的路径train
(可选):如果为True
,则构建训练集,否则构建测试集transform
:定义数据预处理,数据增强方案都是在这里指定target_transform
:标注的预处理,分类任务不常用download
:是否下载,若为True
则从互联网下载,如果已经在dataset_dir
下存在,就不会再次下载
数据增强:在 transform
中指定参数
import torchvision.transforms as transforms
custom_transform = transforms.Compose([
transforms.Resize((64, 64)), # 缩放到指定大小(64*64)
transforms.ColorJitter(0.2, 0.2, 0.2), # 随机颜色变换
transforms.RandomRotation(5), # 随机旋转
transforms.Normalize([0.485,0.456,0.406], # 对图像像素进行归一化
[0.229,0.224,0.225])])
train_data = CIFAR10('./data', train=True, transform=custom_transforms, target_transform=None, download=False)
使用 DataLoader:
from torch.utils.data import DataLoader
# 实现数据批量读取
train_loader = DataLoader(train_data, batch_size=2, shuffle=True, num_workers=4)
batch_size
:设置批次大小shuffle
:在装载过程中随机乱序num_workers
:>=1
表示多进程读取数据,在 Windows 下num_workers
只能设置为0
,否则会报错。
MNIST
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
# 训练集
train_set = MNIST('./data', train=True, transform=transforms.ToTensor(), download=True)
# 测试集
test_set = MNIST('./data', train=False, transform=transforms.ToTensor(), download=True)
# 训练集载入器
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
# 测试集载入器
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
# 可视化数据
import random
for i in range(4):
ax = plt.subplot(2, 2, i+1)
idx = random.randint(0, len(train_set))
digit_0 = train_set[idx][0].numpy()
digit_0_image = digit_0.reshape(28, 28)
ax.imshow(digit_0_image, interpolation="nearest")
ax.set_title('label: {}'.format(train_set[idx][1]), fontsize=10, color='black')
plt.show()