dataset模块之Dataset类
在上一篇文章中我们讲到了Cifar10Dataset类的父类MappableDataset类以及MappableDataset类的父类SourceDataset类。本篇文章我们将介绍Cifar10Dataset类最原始的基类Dataset类,它是SourceDataset类的父类,讲完它之后也就意味着我们对Cifar10Dataset类的探索结束了。话不多说,让我们直接进入正题。
Dataset类
类定义:
class Dataset:
该类用来表示DataEngine数据管道中的数据集。并且它还是SourceDataset类的基类,表示数据流图中的节点。
初始化方法:
def __init__(self, children=None, num_parallel_workers=None, cache=None):
# children是否为None,若为None的话,则用[]替代
self.children = replace_none(children, [])
# 判断children是否为元组类型,是的话转化为list类型
if isinstance(self.children, tuple):
self.children = list(self.children)
if not isinstance(self.children, list):
self.children = [self.children]
self.parent = []
for child in self.children:
child.parent.append(weakref.ref(self))
self.num_parallel_workers = num_parallel_workers
self.cache = cache
self._device_iter = 0
self._input_indexs = ()
self.saved_output_types = None
self.saved_output_shapes = None
self.dynamic_setting = [False, None]
self.saved_min_shapes = None
self.saved_max_shapes = None
self._col_names = None
self.dataset_size = None
self._batch_size = None
self._num_classes = None
self._repeat_count = None
self._class_indexing = None
self._sync = False
参数:num_parallel_workers:数据格式:int(可选) 含义:并行处理数据集的worker数(即使用处理器内核数)
注意:变量children和parent是内部变量(局部变量),不建议在外部类使用
create_ir_tree方法:
def create_ir_tree(self):
parent = self.parent
self.parent = []
# copy.deepcopy方法为对python对象进行深度复刻
dataset = copy.deepcopy(self)
global _OP_NAME
# Dataset._get_operator_id:迭代树并获取每一个操作符的id
_OP_NAME = Dataset._get_operator_id(dataset)
# dataset.parse_tree:将API树解析为IR树
ir_tree = dataset.parse_tree()
self.parent = parent
_init_device_info()
return ir_tree, dataset
功能:构建IR树(中间表达式)的内部方法
参数:无
返回值:DatasetNode:IR树的根节点 Dataset:IR树的根数据集
close_pool方法:
def close_pool(self):
if hasattr(self, 'process_pool') and self.process_pool is not None:
self.process_pool.close()
for child in self.children:
child.close_pool()
功能:关闭数据集中多处理池(在初始化参数中可以设置并行处理数据集)
参数:无
to_json方法:
def to_json(self, filename=""):
ir_tree, _ = self.create_ir_tree()
return json.loads(ir_tree.to_json(filename))
功能:将管道序列化为JSON字符串并转储到文件(文件名为参数filename)中
参数:filename:数据格式:str 含义:要另存为JSON文件的文件名
返回值:位于管道中的JSON字符串
batch方法:
@check_batch
def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
input_columns=None, output_columns=None, column_order=None, pad_info=None, python_multiprocessing=False):
return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns,
output_columns, column_order, pad_info, python_multiprocessing)
功能:将batch_size相同的连续行数合并为batch。对于任何的子节点,batch都被视为一行。对于任何列,该列中的所有元素都必须具有相同的shape。并且如果提供了可调用的_batch_map映射,则它应该被用于张量(Tensor)的batches。
函数修饰器: @check_batch:检查batch方法中输入的参数是否合理
参数:
参数 | 数据格式 | 含义 |
---|---|---|
batch_size | int或者是函数 | 创建每个batch的行数(即数据个数) |
drop_remainder | bool,默认值为false | 确定是否删除数据行数(数据个数)小于batch_size的最后一个batch |
num_parallel_workers | int,默认值为None | 并行处理数据集的worker数(即使用处理器内核数) |
per_batch_map | function(列表的集合(list[Tensor],list[Tensor]...,BatchInfo)) | 输入参数的可调用函数,每个list[Tensor]对应给定输入列中的张量,列表的数量应与输入列中的条目数量相匹配,可调用的最后一个参数应始终是BatchInfo对象 |
input_columns | str或list[str] | 输入数据集各列名称的一个列表,列表的大小应与per_batch_map相匹配 |
output_columns | str或list[str],默认值None,输出列将与输入列具有相同的名称。 | 为上一次操作输出的列指定的名称列表,此列表的大小必须与上次操作的输出列数匹配 |
column_order | str或list[str] | 指定整个数据集中所需的所有列的列表 |
pad_info | dict | 是否对所选列执行填充。若pad_info={“col1”:([224224],0)},将名称为“col1”的列填充到大小为[224224]的张量中,并用0(默认值=无)填充缺少的列。 |
python_multiprocessing | bool | 使用多处理对per_batch_map的python函数进行并行化。如果函数计算量大(默认值为False),则此选项可能会很有用。 |
返回值:已经分批处理好(batches)好的数据集
sync_wait方法:
@check_sync_wait
def sync_wait(self, condition_name, num_batch=1, callback=None):
return SyncWaitDataset(self, condition_name, num_batch, callback)
功能:向输入数据集添加阻塞条件,将应用于同步操作
函数修饰器:@check_sync_wait:检查sync_wait方法中输入的参数是否合理
参数:
参数 | 数据格式 | 含义 |
---|---|---|
condition_name | str | 用于切换发送下一行的条件名称 |
num_batch=1 | int(默认为1) | 每个epoch开始时,未阻塞的batch数 |
callback | 函数(function) | 调用sync_update时将调用的回调函数 |
shuffle方法:
@check_shuffle
def shuffle(self, buffer_size):
return ShuffleDataset(self, buffer_size)
功能:使用一下策略随机混肴此数据集的行,可以提供一个随机种子用于第一个epoch,在随后的每个epoch中,该种子都会被更改为新的随机生成值
- Make a shuffle buffer that contains the first buffer_size rows.
- Randomly select an element from the shuffle buffer to be the next row propagated to the child node.
- Get the next row (if any) from the parent node and put it in the shuffle buffer.
- Repeat steps 2 and 3 until there are no more rows left in the shuffle buffer.
参数:buffer_size:数据格式:int 含义:混肴的缓存区(必须大于1),将buffer_size设置为整个数据集中的行数(一行代表一个数据)将导致全局混肴
返回值:已经被混肴的数据集
flat_map方法:
def flat_map(self, func):
dataset = None
if not hasattr(func, '__call__'):
logger.error("func must be a function.")
raise TypeError("func must be a function.")
for row_data in self.create_tuple_iterator(output_numpy=True):
if dataset is None:
dataset = func(row_data)
else:
dataset += func(row_data)
if not isinstance(dataset, Dataset):
logger.error("flat_map must return a Dataset object.")
raise TypeError("flat_map must return a Dataset object.")
return dataset
功能:将func作用到数据集中的每一行并在最后展平结果
参数:func:一个函数 注意:func必须为将一个'Ndarray'作为参数并返回'Dataset'的函数
返回值:数据集,该数据集的每一行都被应用过func
在这里附一个官网的例程:
>>> dataset = ds.NumpySlicesDataset([[0, 1], [2, 3]])
>>>
>>> def flat_map_func(array):
... # create a NumpySlicesDataset with the array
... dataset = ds.NumpySlicesDataset(array)
... # repeat the dataset twice
... dataset = dataset.repeat(2)
... return dataset
>>>
>>> dataset = dataset.flat_map(flat_map_func)
>>> # [[0, 1], [0, 1], [2, 3], [2, 3]]
map方法:
def map(self, operations, input_columns=None, output_columns=None, column_order=None,
num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None):
return MapDataset(self, operations, input_columns, output_columns, column_order, num_parallel_workers,
python_multiprocessing, cache, callbacks)
功能:将operations中的每一个操作都应用于此数据集,操作顺序由每一个操作在operations中的顺序决定,第一次应用operations[0] , 然后应用perations[1], 然后应用operations[2], 等等等。每个operation将从数据集中传递一个或多个列作为输入,并输出零个或多个列。第一个operation将作为输入传递到input_columns中指定的列。如果操作中有多个运算符,则前一个operation的输出列将用作下一个operation的输入列。最后一次操作输出的列将被分配由输出列指定的名称。
参数:它的很多参数我们都在上面的batch方法中介绍过,这里我们看一下没有介绍过的参数operations。
operations:数据格式:list[TesnorOp],list[functions] 含义:要应用于数据集的operations,操作将按照它们在此列表中出现的顺序应用。
返回值:MapDataset,应用完operations中的操作后的数据集
filter方法:
@check_filter
def filter(self, predicate, input_columns=None, num_parallel_workers=None):
return FilterDataset(self, predicate, input_columns, num_parallel_workers)
功能:通过预测过滤数据集
函数修饰器: @check_filter:检查输入参数是否符合规范
参数:predicate:可以返回bool值得函数(Python可调用),如果返回false,则将该数据过滤掉。
其它参数我们均已讲过。
返回值:FilterDataset,已经被筛选过的数据集
skip方法:
@check_skip
def skip(self, count):
return SkipDataset(self, count)
功能:跳过此数据集的前n个元素
函数修饰器:@check_skip:检查输入参数是否符合规范
参数:count:数据格式:int 含义:此数据集中需要被跳过的数据的个数
返回值:SkipDataset,已经被跳过前count个元素的数据集
zip方法:
@check_zip_dataset
def zip(self, datasets):
if isinstance(datasets, tuple):
datasets = (self, *datasets)
elif isinstance(datasets, Dataset):
datasets = (self, datasets)
else:
raise TypeError("Invalid datasets, expected Dataset object or tuple of Dataset, but got %s!" % datasets)
return ZipDataset(datasets)
功能:在数据集的输入元组的意义上压缩数据集。输入数据集中的列必须具有不同的名称
函数修饰器:@check_zip_dataset:检查输入参数是否符合规范
参数:datasets:数据集的元组或单个数据集类,要联合于该数据集进行压缩
返回值:ZipDataset,压缩过后的数据集
对于上面所述的方法,它们是我目前使用过的或者是见过他人使用的。实际上Dataset类还有很多其它的方法,它们在某些具体的场合且起着重要的作用。感兴趣的读者可前往https://gitee.com/mindspore/mindspore/blob/master/mindspore/dataset/engine/datasets.py和https://www.mindspore.cn/docs/api/zh-CN/r1.3/api_python/mindspore.dataset.html。在下篇文章中,我将带着大家去实际操作Cifar10Dataset类的一些方法。让我们下篇文章再见。