mxnet model.fit源码解析

 model.fit源码分析

首先来到module模块中,即https://github.com/apache/incubator-mxnet/tree/master/python/mxnet/module,进入base_module.py中,我们便可以看到fit()的原型。

 

class BaseModule(object):
    """Base class of MXNet modules, providing the high-level train/score API.

    NOTE(review): only the high-level API portion is shown here; the full
    class (bind/init_params/forward/backward/...) lives in base_module.py.
    """
    ################################################################################
    # High Level API
    ################################################################################
    def forward_backward(self, data_batch):
        """A convenient function that calls both ``forward`` and ``backward``."""
        self.forward(data_batch, is_train=True)
        self.backward()

    # Evaluate the module on a validation set.
    def score(self, eval_data, eval_metric, num_batch=None, batch_end_callback=None,
              score_end_callback=None,
              reset=True, epoch=0, sparse_row_id_fn=None):
        """Runs prediction on ``eval_data`` and evaluates the performance according to
        the given ``eval_metric``.

        Checkout `Module Tutorial <https://mxnet.apache.org/api/python/tutorials/packages/module/index.html>`_
        to see an end-to-end use-case.

        Parameters
        ----------
        eval_data : DataIter
            Evaluation data to run prediction on.
        eval_metric : EvalMetric or list of EvalMetrics
            Evaluation metric to use.
        num_batch : int
            Number of batches to run. Defaults to ``None``, indicating run until the `DataIter`
            finishes.
        batch_end_callback : function
            Could also be a list of functions.
        reset : bool
            Defaults to ``True``. Indicates whether we should reset `eval_data` before starting
            evaluating.
        epoch : int
            Defaults to 0. For compatibility, this will be passed to callbacks (if any).
            During training, this will correspond to the training epoch number.
        sparse_row_id_fn : A callback function
            The function  takes `data_batch` as an input and returns a dict of
            str -> NDArray. The resulting dict is used for pulling row_sparse
            parameters from the kvstore, where the str key is the name of the param,
            and the value is the row id of the param to pull.

        Examples
        --------
        >>> # An example of using score for prediction.
        >>> # Evaluate accuracy on val_dataiter
        >>> metric = mx.metric.Accuracy()
        >>> mod.score(val_dataiter, metric)
        >>> mod.score(val_dataiter, ['mse', 'acc'])
        """
        assert self.binded and self.params_initialized

        # Optionally rewind the validation iterator before scoring.
        if reset:
            eval_data.reset()

        # Accept a metric name (str) or list of names; convert to EvalMetric.
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        eval_metric.reset()
        actual_num_batch = 0

        # Iterate over validation batches (optionally capped at num_batch).
        for nbatch, eval_batch in enumerate(eval_data):
            if num_batch is not None and nbatch == num_batch:
                break
            # Stage the batch; with sparse_row_id_fn this pulls the needed
            # row_sparse parameter rows from the kvstore (see docstring).
            self.prepare(eval_batch, sparse_row_id_fn=sparse_row_id_fn)
            # Forward pass only -- no gradients needed for evaluation.
            self.forward(eval_batch, is_train=False)
            # Feed this batch's labels into every metric's update().
            if isinstance(eval_batch, list):
                self.update_metric(eval_metric, [eb.label for eb in eval_batch], pre_sliced=True)
            else:
                self.update_metric(eval_metric, eval_batch.label)

            # Fire the per-batch callback(s), if any.
            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch,
                                                 nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            actual_num_batch += 1

        # Fire the end-of-scoring callback(s), if any.
        if score_end_callback:
            params = BatchEndParam(epoch=epoch,
                                   nbatch=actual_num_batch,
                                   eval_metric=eval_metric,
                                   locals=locals())
            for callback in _as_list(score_end_callback):
                callback(params)

        # Return the accumulated (name, value) pairs of all metrics.
        return eval_metric.get_name_value()

    def fit(self, train_data, eval_data=None, eval_metric='acc',
            epoch_end_callback=None, batch_end_callback=None, kvstore='local',
            optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
            eval_end_callback=None,
            eval_batch_end_callback=None, initializer=Uniform(0.01),
            arg_params=None, aux_params=None, allow_missing=False,
            force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None,
            validation_metric=None, monitor=None, sparse_row_id_fn=None):
        """Trains the module parameters.

        Checkout `Module Tutorial <https://mxnet.apache.org/api/python/tutorials/packages/module/index.html>`_
        to see an end-to-end use-case.

        Parameters
        ----------
        train_data : DataIter
            Training data iterator.
        eval_data : DataIter
            If not ``None``, it is used as the validation set and performance
            on it is evaluated after every epoch.
        eval_metric : str or EvalMetric
            Defaults to 'accuracy'. The performance measure displayed during
            training. Other predefined metrics include: 'ce' (CrossEntropy),
            'f1', 'mae', 'mse', 'rmse', 'top_k_accuracy'.
        epoch_end_callback : function or list of functions
            Called at the end of each epoch with arguments
            `epoch`, `symbol`, `arg_params` and `aux_params`.
        batch_end_callback : function or list of function
            Called at the end of each batch with a `BatchEndParam` argument.
        kvstore : str or KVStore
            Where parameter updates happen; defaults to 'local'.
            'device': aggregate gradients and update weights on the GPU(s).
            'local': update weights on the CPU.
            'dist_device_sync': distributed training.
        optimizer : str or Optimizer
            The optimizer to use; defaults to 'sgd'.
        optimizer_params : dict
            Optimizer parameters; defaults to (('learning_rate', 0.01),).
        eval_end_callback : function or list of function
            Called once a full evaluation pass is finished.
        eval_batch_end_callback : function or list of function
            Called after each evaluation batch.
        initializer : Initializer
            Called to initialize the module parameters if they are not
            already initialized.
        arg_params : dict
            Defaults to None. If not None, these values take precedence over
            `initializer` for the corresponding parameters.
        aux_params : dict
            Defaults to None. If not None, these values take precedence over
            `initializer` for the corresponding auxiliary states.
        allow_missing : bool
            Defaults to False. Indicates whether parameters may be missing
            when arg_params and aux_params are not None. If True, any missing
            parameters are initialized via `initializer` instead.
        force_rebind : bool
            Defaults to False. Whether to force rebinding the executors if
            they are already bound.
        force_init : bool
            Defaults to False. Indicates whether to force initialization even
            if the parameters are already initialized.
        begin_epoch : int
            Defaults to 0. The starting epoch. Typically, if a previous
            training run was checkpointed at Epoch[n], resuming training
            should pass n+1 here.
        num_epoch : int
            Number of epochs to train for (required).
        validation_metric : str or EvalMetric
            Metric used on `eval_data`; defaults to `eval_metric` when None.
        monitor : Monitor
            If not None, installed on the module (via `install_monitor`) to
            inspect internal states during training.
        sparse_row_id_fn : A callback function
            The function  takes `data_batch` as an input and returns a dict of
            str -> NDArray. The resulting dict is used for pulling row_sparse
            parameters from the kvstore, where the str key is the name of the param,
            and the value is the row id of the param to pull.

        Examples
        --------
        >>> # An example of using fit for training.
        >>> # Assume training dataIter and validation dataIter are ready
        >>> # Assume loading a previously checkpointed model
        >>> sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
        >>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter, optimizer='sgd',
        ...     optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
        ...     arg_params=arg_params, aux_params=aux_params,
        ...     eval_metric='acc', num_epoch=10, begin_epoch=3)
        """
        assert num_epoch is not None, 'please specify number of epochs'

        # Bind the symbols to the training data/label shapes (allocates executors).
        self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
                  for_training=True, force_rebind=force_rebind)
        if monitor is not None:
            self.install_monitor(monitor)
        # Initialize the weights; see the parameter docs above for precedence
        # between initializer, arg_params/aux_params and allow_missing.
        self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
                         allow_missing=allow_missing, force_init=force_init)
        # Set up the optimizer and kvstore.
        self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                            optimizer_params=optimizer_params)

        # Fall back to the training metric for validation if none was given.
        if validation_metric is None:
            validation_metric = eval_metric
        # Convert a string metric name into an EvalMetric instance.
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        ################################################################################
        # training loop
        ################################################################################
        # Main loop over epochs.
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            # Reset metric state at the start of each epoch.
            eval_metric.reset()
            # Batch counter within the current epoch.
            nbatch = 0
            data_iter = iter(train_data)
            end_of_batch = False
            next_data_batch = next(data_iter)
            # Consume one batch per iteration; the following batch is
            # pre-fetched inside the loop so prepare() can run early.
            while not end_of_batch:
                data_batch = next_data_batch
                if monitor is not None:
                    monitor.tic()
                # Forward pass followed by backward pass (computes gradients).
                self.forward_backward(data_batch)
                # Apply the optimizer's weight update using those gradients.
                self.update()

                # Accumulate training metrics (metric.update) for this batch.
                if isinstance(data_batch, list):
                    self.update_metric(eval_metric,
                                       [db.label for db in data_batch],
                                       pre_sliced=True)
                else:
                    self.update_metric(eval_metric, data_batch.label)

                # Pre-fetch the next batch; StopIteration marks end of epoch.
                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                    self.prepare(next_data_batch, sparse_row_id_fn=sparse_row_id_fn)
                except StopIteration:
                    end_of_batch = True

                if monitor is not None:
                    monitor.toc_print()

                # On the last batch, snapshot the epoch-level (name, value)
                # results of all metrics for the end-of-epoch log below.
                if end_of_batch:
                    eval_name_vals = eval_metric.get_global_name_value()

                # Fire the per-batch callback(s), e.g. Speedometer logging.
                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                nbatch += 1

            # one epoch of training is finished
            # Log the epoch-level training metrics, e.g. "Train-acc=...".
            for name, val in eval_name_vals:
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            # Log the wall-clock time spent on this epoch.
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc-tic))

            # Pull parameters back from the devices and push them again so
            # all devices hold the same values.
            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            # End-of-epoch callback(s), e.g. checkpoint saving.
            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

            #----------------------------------------
            # evaluation on validation set
            # Score on the validation set; validation_metric equals the
            # training metric when it was left as None.
            if eval_data:
                res = self.score(eval_data, validation_metric,
                                 score_end_callback=eval_end_callback,
                                 batch_end_callback=eval_batch_end_callback, epoch=epoch)
                #TODO: pull this into default
                # Log the validation results, e.g. "Validation-acc=...".
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val)

            # end of 1 epoch, reset the data-iter for another epoch
            # Rewind the training iterator for the next epoch.
            train_data.reset()

 

训练log源码分析

# Per-batch callback used during training to print throughput/metrics.
_cb = mx.callback.Speedometer(batch_size, frequent)
def _batch_callback(param):
  # Emits the training log line, e.g.:
  # INFO:root:Epoch[26] Batch [0-20] Speed: 257.26 samples/sec acc=0.968571 lossvalue=0.331392
  _cb(param)

 

class Speedometer(object):
    """Logs training speed and evaluation metrics periodically.

    Parameters
    ----------
    batch_size: int
        Batch size of data.
    frequent: int
        Specifies how frequently training speed and evaluation metrics
        must be logged. Default behavior is to log once every 50 batches.
    auto_reset : bool
        Reset the evaluation metrics after each log.

    Example
    -------
    >>> # Print training speed and evaluation metrics every ten batches. Batch size is one.
    >>> module.fit(iterator, num_epoch=n_epoch,
    ... batch_end_callback=mx.callback.Speedometer(1, 10))
    Epoch[0] Batch [10] Speed: 1910.41 samples/sec  Train-accuracy=0.200000
    Epoch[0] Batch [20] Speed: 1764.83 samples/sec  Train-accuracy=0.400000
    Epoch[0] Batch [30] Speed: 1740.59 samples/sec  Train-accuracy=0.500000
    """
    def __init__(self, batch_size, frequent=50, auto_reset=True):
        self.auto_reset = auto_reset
        self.frequent = frequent
        self.batch_size = batch_size
        self.last_count = 0
        self.tic = 0
        self.init = False

    def __call__(self, param):
        """Callback that logs throughput (and metrics) every ``frequent`` batches."""
        count = param.nbatch
        # A batch counter that went backwards means a new epoch started;
        # drop back to the uninitialized state so the timer restarts.
        if count < self.last_count:
            self.init = False
        self.last_count = count

        if not self.init:
            # First observed batch: just start the timer, log nothing.
            self.init = True
            self.tic = time.time()
            return

        # Only log once every `frequent` batches.
        if count % self.frequent != 0:
            return

        # #11504
        # Samples processed per second over the last `frequent` batches.
        try:
            speed = self.frequent * self.batch_size / (time.time() - self.tic)
        except ZeroDivisionError:
            speed = float('inf')

        if param.eval_metric is None:
            logging.info("Iter[%d] Batch [%d]\tSpeed: %.2f samples/sec",
                         param.epoch, count, speed)
        else:
            # Current (name, value) results of every metric.
            name_value = param.eval_metric.get_name_value()
            if self.auto_reset:
                # Metrics cover only the last window; reset for the next one.
                param.eval_metric.reset_local()
                msg = 'Epoch[%d] Batch [%d-%d]\tSpeed: %.2f samples/sec'
                msg += '\t%s=%f' * len(name_value)
                logging.info(msg, param.epoch, count - self.frequent, count,
                             speed, *sum(name_value, ()))
            else:
                # Metrics accumulate from the start of the epoch.
                msg = 'Epoch[%d] Batch [0-%d]\tSpeed: %.2f samples/sec'
                msg += '\t%s=%f' * len(name_value)
                logging.info(msg, param.epoch, count, speed,
                             *sum(name_value, ()))
        self.tic = time.time()

  

 

posted @ 2020-08-14 15:45  陈晓涛  阅读(640)  评论(0编辑  收藏  举报