如何用MindSpore自定义数据集
引言
在深度学习模型的训练过程中,数据集是起着至关重要作用的。然而,由于任务的复杂性,深度学习模型的输入数据也有着各种各样的形式,深度学习模型搭建的过程中,如果遇到特别复杂的数据,研究者可能要花费大半的时间在数据集的预处理(包括清洗、加载等过程)中。因此,高效的加载数据集,能给研究者构建一套高效的开发流程。
使用过PyTorch的读者都知道,PyTorch框架为我们提供了一套极其便利且高效率的自定义数据加载的接口。用户只需要简单的继承torch.utils.data.Dataset并且在get_item函数和__len__函数,再利用Dataloader进行封装,就可以很简单的实现数据集的自动化加载流程(个人认为设置PyTorch在数据层面上做的超级好的一个点)。
如何用MindSpore自定义数据集
MindSpore数据集加载简介
在MindSpore中,mindspore.dataset里面的函数为我们提供了大量的数据集专有加载算子,这些算子经过优化,拥有较好的数据集加载性能。但是,由于MindSpore本身的数据加载都是在C语言层面完成的,用户很难感知到内部进行的具体操作,特别是针对coco这一类较为复杂的数据集时(就是比较黑洞,很难自己掌握)。由于笔者是一个很喜欢把模型训练的每一步都抓在自己手里的一个人,因此除了cifar10、cifar100、imagefolder等经典的数据(结构)时,尽量都希望自己完成数据集的加载流程,以便更好的了解模型模型和数据集。因此,这篇博客将会主要介绍如何使用MindSpore自定自定义类似PyTorch范式的数据集加载流程。
mindspore.dataset.GeneratorDataset
区别用PyTorch,MindSpore并不能像继承Dataset来完成数据集的构建,但是MindSpore为用户提供了一个类似于DataLoader的数据集封装接口。
用户可以通过自定义object对象的数据集对象,然后使用GeneratorDataset进行封装,接下来我将以自定义cifar10和imagenet数据集来简单展示使用GeneratorDataset接口的方法。
自定义cifar10数据集
分析格式
在定义数据集之前,我们首先要做的就是数据集的格式分析。在cifar官网中,我们可以得知数据集的基本格式,还可以通过已有的博客,查看读取cifar10的代码样例。
如下图所示是cifar-10-batches-py数据集的目录文件,这里我们主要是关注data_batch和test_batch。
加载数据
这里我主要以torchvision中的cifar10数据加载为例,说明构建cifar10数据集的方法。
train_list = [
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
['data_batch_4', '634d18415352ddfa80567beed471001a'],
['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
]
test_list = [
['test_batch', '40351d587109b95175f43aff81a1287e'],
]
...
if self.train:
downloaded_list = self.train_list
else:
downloaded_list = self.test_list
...
for file_name, checksum in downloaded_list:
file_path = os.path.join(self.root, self.base_folder, file_name)
with open(file_path, 'rb') as f:
entry = pickle.load(f, encoding='latin1')
self.data.append(entry['data'])
if 'labels' in entry:
self.targets.extend(entry['labels'])
else:
self.targets.extend(entry['fine_labels'])
"""可以很容易理解到,数据集文件里面有一个"data"和一个"label"键,分别拿出来就好"""
self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
构建cifar10数据集并且完成预处理
由于cifar10读取进来以后已经是数据形式,因此并不需要想用的图像解码,可以直接使用opencv或者PIL进行处理。这里以cifar10的test数据为例。
import os
import pickle
import numpy as np
import mindspore
from mindspore.dataset import GeneratorDataset
class CIFAR10(object):
train_list = [
'data_batch_1',
'data_batch_2',
'data_batch_3',
'data_batch_4',
'data_batch_5',
]
test_list = [
'test_batch',
]
def __init__(self, root, train, transform=None, target_transform=None):
super(CIFAR10, self).__init__()
self.root = root
self.train = train # training set or test set
if self.train:
downloaded_list = self.train_list
else:
downloaded_list = self.test_list
self.data = []
self.targets = []
self.transform = transform
self.target_transform = target_transform
# now load the picked numpy arrays
for file_name in downloaded_list:
file_path = os.path.join(self.root, file_name)
with open(file_path, 'rb') as f:
entry = pickle.load(f, encoding='latin1')
self.data.append(entry['data'])
if 'labels' in entry:
self.targets.extend(entry['labels'])
else:
self.targets.extend(entry['fine_labels'])
self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], self.targets[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
return len(self.data)
cifar10_test = CIFAR10(root="./cifar10/cifar-10-batches-py", train=False)
cifar10_test = GeneratorDataset(source=cifar10_test, column_names=["image", "label"])
cifar10_test = cifar10_test.batch(128)
for data in cifar10_test.create_dict_iterator():
print(data["image"