dataset模块之采样器

本篇文章将对我们先前提到过的数据采样器(sampler)进行讲解。对于采样器,我们可以将它理解为帮助用户对数据集进行不同形似的采样的工具。使用它之后可以满足训练所需的要求,并且可以解决诸多如数据集过大或样本分布不均等问题。在大多数情况下,使用采样器可以给最终的训练结果带来提升。

SequentialSampler

SequentialSampler,中文名称为顺序采样器。官方对于它的介绍是按照原始顺序采集数据集中的元素,相当于不使用采样器。在实操Cifar10Dataset类的相关方法时,我们已经使用过该顺序采样器去采集前几个样本数据数据,并将其相关信息打印出来以及可视化。为了进一步的了解它,我们来看一下它的源码。

class SequentialSampler(BuiltinSampler):
    """

    参数:
        start_index (int, 可选): 开始采样的索引值。 (默认值为None,从头开始采集)
        num_samples (int, 可选): 需要采集的元素的数目 (默认值为None,意味着采集数据集中所有元素).

       异常:
        TypeError: 参数start_index的数据格式不为int
        TypeError: 参数num_samples的数据格式不为int
        RuntimeError: 若start_index(开始索引值)为负数
        ValueError: 若num_samples(采集元素数目)为负数
    """

    def __init__(self, start_index=None, num_samples=None):
        if start_index is not None and not isinstance(start_index, int):
            raise TypeError("start_index must be integer but was: {}.".format(start_index))

        if num_samples is not None:
            if not isinstance(num_samples, int):
                raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
            if num_samples < 0 or num_samples > validator.INT64_MAX:
                raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
                                 .format(0, validator.INT64_MAX))

        self.start_index = start_index
        super().__init__(num_samples)

    def parse(self):
        """ 解析采样器 """
        start_index = self.start_index if self.start_index is not None else 0
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.SequentialSamplerObj(start_index, num_samples)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def parse_for_minddataset(self):
        """ 解析MindRecord文件格式采样器 """
        start_index = self.start_index if self.start_index is not None else 0
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.MindrecordSequentialSampler(num_samples, start_index)
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_shuffled()

    def is_sharded(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()

BuiltinSampler

实际上SequentialSampler采样器还有几个常用的方法,比如说add_child(为给定采样器添加子采样器)等。它们被封装在BuiltinSampler类中,该类是所有内置采样器的基类,不能够被用户所继承。接下来我们来看一下它的源码。

class BuiltinSampler:

    def __init__(self, num_samples=None):
        self.child_sampler = None
        self.num_samples = num_samples

    def parse(self):
        """ 解析采样器 """

    def add_child(self, sampler):
        """
          为给定采样器添加子采样器,子采样器将接收父采样器的输出,并将其应用新采样器的逻辑,返回新的样本。

        参数:
            sampler (Sampler类): dataset模块中用于采集元素的类, 仅能包含内置的被提供的采样器,包括
            (DistributedSampler, PKSampler, RandomSampler, SequentialSampler,
                SubsetRandomSampler, WeightedRandomSampler)。

        例子(MindSpore三种环境均可):
            >>> sampler = ds.SequentialSampler(start_index=0, num_samples=3)
            >>> sampler.add_child(ds.RandomSampler(num_samples=2))
            >>> dataset = ds.Cifar10Dataset(cifar10_dataset_dir, sampler=sampler)
        """
        self.child_sampler = sampler

    def get_child(self):
        """ 获得子采样器的实例 """
        return self.child_sampler

    def parse_child(self):
        """ 解析子采样器 """
        c_child_sampler = None
        if self.child_sampler is not None:
            c_child_sampler = self.child_sampler.parse()
        return c_child_sampler

    def parse_child_for_minddataset(self):
        """ 解析MindRecord格式的子采样器 """
        c_child_sampler = None
        if self.child_sampler is not None:
            c_child_sampler = self.child_sampler.parse_for_minddataset()
        return c_child_sampler

    def is_shuffled(self):
        """ Not implemented. """
        raise NotImplementedError("Sampler must implement is_shuffled.")

    def is_sharded(self):
        """ Not implemented. """
        raise NotImplementedError("Sampler must implement is_sharded.")

    def get_num_samples(self):
        """
        所有采样器都可以包含一个数值为num_samples值(也可以是默认值None)
        子采样器可有可无
        同理,若子采样器存在,则它的计数可以是一个数值(也可以是默认值None)
        上述条件会影响所使用的采样器总数
        下面是调用此函数可能产生的结果(例子来于官网),用户使用该方法时可以对应相关情况

        .. list-table::
           :widths: 25 25 25 25
           :header-rows: 1

            (后面所有样例将按照第一行为child sampler,第二行为num_samples,第三行
            为child_samples,第四行为result(结果)的形式给出)

           * - child sampler
             - num_samples
             - child_samples
             - result
           * - T
             - x
             - y
             - min(x, y)
           * - T
             - x
             - None
             - x
           * - T
             - None
             - y
             - y
           * - T
             - None
             - None
             - None
           * - None
             - x
             - n/a
             - x
           * - None
             - None
             - n/a
             - None

        返回值:
            int, sampler的个数或者是None
        """
        if self.child_sampler is not None:
            child_samples = self.child_sampler.get_num_samples()
            if self.num_samples is not None:
                if child_samples is not None:
                    return min(self.num_samples, child_samples)

                return self.num_samples

            return child_samples

        return self.num_samples

RandomSampler

与该类的名字一样,它是一个随机采样器,随机地抽取数据集中的元素样本。相对于顺序采样器来说,它的使用频率是较高的。同分析顺序采样器一样,我们来看一下它的源码。

class RandomSampler(BuiltinSampler):
    """
    Samples the elements randomly.

    参数:
        replacement (bool, 可选):    如果为true,则将该采样器ID放回下一次draw,默认值为false
        num_samples (int, 可选): 需要被采集的元素数 (默认值为None,意味着采集所有元素).

    异常:
        TypeError: 若参数replacement的数据格式不为布尔类型
        TypeError: 若参数num_samples的数据格式不为整型(int)
        ValueError: 若参数num_samples为负数
     """

    def __init__(self, replacement=False, num_samples=None):
        if not isinstance(replacement, bool):
            raise TypeError("replacement must be a boolean value but was: {}.".format(replacement))

        if num_samples is not None:
            if not isinstance(num_samples, int):
                raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
            if num_samples < 0 or num_samples > validator.INT64_MAX:
                raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
                                 .format(0, validator.INT64_MAX))

        self.deterministic = False
        self.replacement = replacement
        self.reshuffle_each_epoch = True
        super().__init__(num_samples)

    def parse(self):
        """ 解析采样器 """
        num_samples = self.num_samples if self.num_samples is not None else 0
        replacement = self.replacement if self.replacement is not None else False
        c_sampler = cde.RandomSamplerObj(replacement, num_samples, self.reshuffle_each_epoch)
        c_child_sampler = self.parse_child()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def parse_for_minddataset(self):
        """解析MindRecord文件格式采样器"""
        num_samples = self.num_samples if self.num_samples is not None else 0
        c_sampler = cde.MindrecordRandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
        c_child_sampler = self.parse_child_for_minddataset()
        c_sampler.add_child(c_child_sampler)
        return c_sampler

    def is_shuffled(self):
        return True

    def is_sharded(self):
        if self.child_sampler is None:
            return False

        return self.child_sampler.is_sharded()

上述的两个采样器是比较常见的,我们在很多场合下都能使用到它们。同时我们还讲了这些内置采样器的基类BuiltinSampler类的一些常见方法(它们几乎被所有内置采样器继承并且使用),相信读者看了这部分内容后一定可以在借助API的情况下,完成对其它常见采样器(包括带权随机采样器WeightedRandomSampler,子集随机采样器SubsetRandomSampler等)的使用。

自定义采样器

接下来,我们来讲解一下如何通过自定义采样器来完成对数据集的样本数据采集工作,内容相对也比较简单。

首先,我们构造的采样器类必须继承自Sampler类,并且通过实现__iter__方法来自定义采样器的采样方式。

在这里,我将构建一个从下表0~20,间隔为3的采样器,并将其作用于CIFAR-10数据集,展示抽取样本数据的标签和形状。由于该采样器的逻辑比较简单,我们可以通过遍历CIFAR-10数据集的0~20个样本来验证该采样器成功实现我们所写的逻辑。

import mindspore.dataset as ds

class MySampler(ds.Sampler):
    def __iter__(self):
        for i in range(0, 10, 2):
            yield i

DATA_DIR = "./datasets/cifar-10-batches-bin/train/"

dataset = ds.Cifar10Dataset(DATA_DIR, sampler=MySampler())

for data in dataset.create_dict_iterator():
    print("Image shape:", data['image'].shape, ", Label:", data['label'])

print()
dataset = ds.Cifar10Dataset(DATA_DIR)
i = 0
for data in dataset.create_dict_iterator():
    if i % 2 == 0 and i <= 9:
    print("Image shape:", data['image'].shape, ", Label:", data['label'])
    i = i + 1

运行结果:

Image shape: (32, 32, 3) , Label: 6
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 1
Image shape: (32, 32, 3) , Label: 2
Image shape: (32, 32, 3) , Label: 8

Image shape: (32, 32, 3) , Label: 6
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 1
Image shape: (32, 32, 3) , Label: 2
Image shape: (32, 32, 3) , Label: 8

上述结果显而易见地证明了自定义采样器的合理性和正确性。

posted @ 2021-12-27 16:39  MS小白  阅读(44)  评论(0)    收藏  举报