dataset模块之Dataset类

在上一篇文章中我们讲到了Cifar10Dataset类的父类MappableDataset类以及MappableDataset类的父类SourceDataset类。本篇文章我们将介绍Cifar10Dataset类最原始的基类Dataset类,它是SourceDataset类的父类,讲完它之后也就意味着我们对Cifar10Dataset类的探索结束了。话不多说,让我们直接进入正题。

Dataset类

类定义:

class Dataset:

该类用来表示DataEngine数据管道中的数据集。并且它还是SourceDataset类的基类,表示数据流图中的节点。

初始化方法:

    def __init__(self, children=None, num_parallel_workers=None, cache=None):
        # children是否为None,若为None的话,则用[]替代
        self.children = replace_none(children, [])
        # 判断children是否为元组类型,是的话转化为list类型
        if isinstance(self.children, tuple):
            self.children = list(self.children)
        if not isinstance(self.children, list):
            self.children = [self.children]

        self.parent = []
        for child in self.children:
            child.parent.append(weakref.ref(self))
        self.num_parallel_workers = num_parallel_workers
        self.cache = cache

        self._device_iter = 0
        self._input_indexs = ()
        self.saved_output_types = None
        self.saved_output_shapes = None
        self.dynamic_setting = [False, None]
        self.saved_min_shapes = None
        self.saved_max_shapes = None
        self._col_names = None
        self.dataset_size = None
        self._batch_size = None
        self._num_classes = None
        self._repeat_count = None
        self._class_indexing = None
        self._sync = False

参数:num_parallel_workers:数据格式:int(可选) 含义:并行处理数据集的worker数(即使用处理器内核数)

注意:变量children和parent是内部变量(局部变量),不建议在外部类使用

create_ir_tree方法:

    def create_ir_tree(self):
        parent = self.parent
        self.parent = []
        # copy.deepcopy方法为对python对象进行深度复刻
        dataset = copy.deepcopy(self)
        global _OP_NAME
        # Dataset._get_operator_id:迭代树并获取每一个操作符的id
        _OP_NAME = Dataset._get_operator_id(dataset)
        # dataset.parse_tree:将API树解析为IR树
        ir_tree = dataset.parse_tree()
        self.parent = parent
        _init_device_info()
        return ir_tree, dataset

功能:构建IR树(中间表达式)的内部方法

参数:无

返回值:DatasetNode:IR树的根节点 Dataset:IR树的根数据集

close_pool方法:

    def close_pool(self):
        if hasattr(self, 'process_pool') and self.process_pool is not None:
            self.process_pool.close()
        for child in self.children:
            child.close_pool()

功能:关闭数据集中多处理池(在初始化参数中可以设置并行处理数据集)

参数:无

to_json方法:

   def to_json(self, filename=""):
        ir_tree, _ = self.create_ir_tree()
        return json.loads(ir_tree.to_json(filename))

功能:将管道序列化为JSON字符串并转储到文件(文件名为参数filename)中

参数:filename:数据格式:str 含义:要另存为JSON文件的文件名

返回值:位于管道中的JSON字符串

batch方法:

    @check_batch
    def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, per_batch_map=None,
              input_columns=None, output_columns=None, column_order=None, pad_info=None, python_multiprocessing=False):
        return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns,
                            output_columns, column_order, pad_info, python_multiprocessing)

功能:将batch_size相同的连续行数合并为batch。对于任何的子节点,batch都被视为一行。对于任何列,该列中的所有元素都必须具有相同的shape。并且如果提供了可调用的_batch_map映射,则它应该被用于张量(Tensor)的batches。

函数修饰器: @check_batch:检查batch方法中输入的参数是否合理

参数:

参数数据格式含义
batch_size int或者是函数 创建每个batch的行数(即数据个数)
drop_remainder bool,默认值为false 确定是否删除数据行数(数据个数)小于batch_size的最后一个batch
num_parallel_workers int,默认值为None 并行处理数据集的worker数(即使用处理器内核数)
per_batch_map function(列表的集合(list[Tensor],list[Tensor]...,BatchInfo)) 输入参数的可调用函数,每个list[Tensor]对应给定输入列中的张量,列表的数量应与输入列中的条目数量相匹配,可调用的最后一个参数应始终是BatchInfo对象
input_columns str或list[str] 输入数据集各列名称的一个列表,列表的大小应与per_batch_map相匹配
output_columns str或list[str],默认值None,输出列将与输入列具有相同的名称。 为上一次操作输出的列指定的名称列表,此列表的大小必须与上次操作的输出列数匹配
column_order str或list[str] 指定整个数据集中所需的所有列的列表
pad_info dict 是否对所选列执行填充。若pad_info={“col1”:([224224],0)},将名称为“col1”的列填充到大小为[224224]的张量中,并用0(默认值=无)填充缺少的列。
python_multiprocessing bool 使用多处理对per_batch_map的python函数进行并行化。如果函数计算量大(默认值为False),则此选项可能会很有用。

返回值:已经分批处理好(batches)好的数据集

sync_wait方法:

   @check_sync_wait
    def sync_wait(self, condition_name, num_batch=1, callback=None):
        return SyncWaitDataset(self, condition_name, num_batch, callback)

功能:向输入数据集添加阻塞条件,将应用于同步操作

函数修饰器:@check_sync_wait:检查sync_wait方法中输入的参数是否合理

参数:

参数数据格式含义
condition_name str 用于切换发送下一行的条件名称
num_batch=1 int(默认为1) 每个epoch开始时,未阻塞的batch数
callback 函数(function) 调用sync_update时将调用的回调函数

shuffle方法:

  @check_shuffle
    def shuffle(self, buffer_size):
        return ShuffleDataset(self, buffer_size)

功能:使用一下策略随机混肴此数据集的行,可以提供一个随机种子用于第一个epoch,在随后的每个epoch中,该种子都会被更改为新的随机生成值

  1. Make a shuffle buffer that contains the first buffer_size rows.
  2. Randomly select an element from the shuffle buffer to be the next row propagated to the child node.
  3. Get the next row (if any) from the parent node and put it in the shuffle buffer.
  4. Repeat steps 2 and 3 until there are no more rows left in the shuffle buffer.

参数:buffer_size:数据格式:int 含义:混肴的缓存区(必须大于1),将buffer_size设置为整个数据集中的行数(一行代表一个数据)将导致全局混肴

返回值:已经被混肴的数据集

flat_map方法:

   def flat_map(self, func):
        dataset = None
        if not hasattr(func, '__call__'):
            logger.error("func must be a function.")
            raise TypeError("func must be a function.")

        for row_data in self.create_tuple_iterator(output_numpy=True):
            if dataset is None:
                dataset = func(row_data)
            else:
                dataset += func(row_data)

        if not isinstance(dataset, Dataset):
            logger.error("flat_map must return a Dataset object.")
            raise TypeError("flat_map must return a Dataset object.")
        return dataset

功能:将func作用到数据集中的每一行并在最后展平结果

参数:func:一个函数 注意:func必须为将一个'Ndarray'作为参数并返回'Dataset'的函数

返回值:数据集,该数据集的每一行都被应用过func

在这里附一个官网的例程:

            >>> dataset = ds.NumpySlicesDataset([[0, 1], [2, 3]])
            >>>
            >>> def flat_map_func(array):
            ...     # create a NumpySlicesDataset with the array
            ...     dataset = ds.NumpySlicesDataset(array)
            ...     # repeat the dataset twice
            ...     dataset = dataset.repeat(2)
            ...     return dataset
            >>>
            >>> dataset = dataset.flat_map(flat_map_func)
            >>> # [[0, 1], [0, 1], [2, 3], [2, 3]]

map方法:

   def map(self, operations, input_columns=None, output_columns=None, column_order=None,
            num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None):
        return MapDataset(self, operations, input_columns, output_columns, column_order, num_parallel_workers,
                          python_multiprocessing, cache, callbacks)

功能:将operations中的每一个操作都应用于此数据集,操作顺序由每一个操作在operations中的顺序决定,第一次应用operations[0] , 然后应用perations[1], 然后应用operations[2], 等等等。每个operation将从数据集中传递一个或多个列作为输入,并输出零个或多个列。第一个operation将作为输入传递到input_columns中指定的列。如果操作中有多个运算符,则前一个operation的输出列将用作下一个operation的输入列。最后一次操作输出的列将被分配由输出列指定的名称。

参数:它的很多参数我们都在上面的batch方法中介绍过,这里我们看一下没有介绍过的参数operations。

operations:数据格式:list[TesnorOp],list[functions] 含义:要应用于数据集的operations,操作将按照它们在此列表中出现的顺序应用。

返回值:MapDataset,应用完operations中的操作后的数据集

filter方法:

   @check_filter
    def filter(self, predicate, input_columns=None, num_parallel_workers=None):
        return FilterDataset(self, predicate, input_columns, num_parallel_workers)

功能:通过预测过滤数据集

函数修饰器: @check_filter:检查输入参数是否符合规范

参数:predicate:可以返回bool值得函数(Python可调用),如果返回false,则将该数据过滤掉。

其它参数我们均已讲过。

返回值:FilterDataset,已经被筛选过的数据集

skip方法:

    @check_skip
    def skip(self, count):
        return SkipDataset(self, count)

功能:跳过此数据集的前n个元素

函数修饰器:@check_skip:检查输入参数是否符合规范

参数:count:数据格式:int 含义:此数据集中需要被跳过的数据的个数

返回值:SkipDataset,已经被跳过前count个元素的数据集

zip方法:

    @check_zip_dataset
    def zip(self, datasets):

        if isinstance(datasets, tuple):
            datasets = (self, *datasets)
        elif isinstance(datasets, Dataset):
            datasets = (self, datasets)
        else:
            raise TypeError("Invalid datasets, expected Dataset object or tuple of Dataset, but got %s!" % datasets)
        return ZipDataset(datasets)

功能:在数据集的输入元组的意义上压缩数据集。输入数据集中的列必须具有不同的名称

函数修饰器:@check_zip_dataset:检查输入参数是否符合规范

参数:datasets:数据集的元组或单个数据集类,要联合于该数据集进行压缩

返回值:ZipDataset,压缩过后的数据集

对于上面所述的方法,它们是我目前使用过的或者是见过他人使用的。实际上Dataset类还有很多其它的方法,它们在某些具体的场合且起着重要的作用。感兴趣的读者可前往https://gitee.com/mindspore/mindspore/blob/master/mindspore/dataset/engine/datasets.py和https://www.mindspore.cn/docs/api/zh-CN/r1.3/api_python/mindspore.dataset.html。在下篇文章中,我将带着大家去实际操作Cifar10Dataset类的一些方法。让我们下篇文章再见。

posted @ 2021-12-27 16:40  MS小白  阅读(168)  评论(0)    收藏  举报