主题:训练循环定制化实战:以CustomTrainer.fit为例

Posted on 2025-11-27 15:59  0泡  阅读(1)  评论(0)    收藏  举报

主题:训练循环定制化实战:以CustomTrainer.fit为例

日期:2025-11-27

引言

大多数推荐系统训练器都沿用“训练-验证-保存”的套路,但在真实验里,我们往往需要扩展日志、显存控制、早停策略甚至外部回调。
Gaowei在训练自己的一个推荐模型的时候就写了一个CustomTrainer.fit,以“定制化训练方法技巧”为核心,抽象出通用实践,给大家分享下。
具体代码如下:

    def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None):
        r"""Train the model based on the train data and the valid data.
        
        This method overrides the parent fit method to add TensorBoard logging for all evaluation metrics.
        
        Args:
            train_data (DataLoader): the train data
            valid_data (DataLoader, optional): the valid data, default: None.
            verbose (bool, optional): whether to write training and evaluation information to logger, default: True
            saved (bool, optional): whether to save the model parameters, default: True
            show_progress (bool): Show the progress of training epoch and evaluate epoch. Defaults to ``False``.
            callback_fn (callable): Optional callback function executed at end of epoch.
                                    Includes (epoch_idx, valid_score) input arguments.

        Returns:
             (float, dict): best valid score and best valid result. If valid_data is None, it returns (-1, None)
        """
        # Call parent fit method but override the validation logging part
        if saved and self.start_epoch >= self.epochs:
            self._save_checkpoint(-1)

        self.eval_collector.data_collect(train_data)

        # Initial validation before training
        if valid_data is not None:
            valid_start_time = time()
            valid_score, valid_result = self._valid_epoch(valid_data, show_progress=show_progress)
            valid_end_time = time()
            valid_score_output = (set_color("epoch %d evaluating", 'green') + " [" + set_color("time", 'blue')
                                + ": %.2fs, " + set_color("valid_score", 'blue') + ": %f]") % \
                                    (-1, valid_end_time - valid_start_time, valid_score)
            valid_result_output = set_color('valid result', 'blue') + ': \n' + dict2str(valid_result)
            self.logger.info(valid_score_output)
            self.logger.info(valid_result_output)
            
            # Log all metrics to TensorBoard
            self._log_metrics_to_tensorboard(valid_result, -1, prefix='Valid')
            
            # Clear GPU cache after evaluation to free memory for training
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        for epoch_idx in range(self.start_epoch, self.epochs):
            # train
            training_start_time = time()
            train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress)
            self.train_loss_dict[epoch_idx] = sum(train_loss) if isinstance(train_loss, tuple) else train_loss
            training_end_time = time()
            train_loss_output = \
                self._generate_train_loss_output(epoch_idx, training_start_time, training_end_time, train_loss)
            if verbose:
                self.logger.info(train_loss_output)
            
            # Log training loss to TensorBoard
            if self.tensorboard is not None:
                if isinstance(train_loss, tuple):
                    total_loss = sum(train_loss)
                else:
                    total_loss = train_loss
                self.tensorboard.add_scalar('Train/Loss', total_loss, epoch_idx)

            # eval
            if self.eval_step <= 0 or not valid_data:
                if saved:
                    self._save_checkpoint(epoch_idx)
                    update_output = set_color('Saving current', 'blue') + ': %s' % self.saved_model_file
                    if verbose:
                        self.logger.info(update_output)
                continue
            if (epoch_idx + 1) % self.eval_step == 0:
                valid_start_time = time()
                valid_score, valid_result = self._valid_epoch(valid_data, show_progress=show_progress)
                # Clear GPU cache after evaluation
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                self.best_valid_score, self.cur_step, stop_flag, update_flag = early_stopping(
                    valid_score,
                    self.best_valid_score,
                    self.cur_step,
                    max_step=self.stopping_step,
                    bigger=self.valid_metric_bigger
                )
                valid_end_time = time()
                valid_score_output = (set_color("epoch %d evaluating", 'green') + " [" + set_color("time", 'blue')
                                    + ": %.2fs, " + set_color("valid_score", 'blue') + ": %f]") % \
                                     (epoch_idx, valid_end_time - valid_start_time, valid_score)
                valid_result_output = set_color('valid result', 'blue') + ': \n' + dict2str(valid_result)
                if verbose:
                    self.logger.info(valid_score_output)
                    self.logger.info(valid_result_output)
                
                # Log all metrics to TensorBoard (including Valid_score)
                self._log_metrics_to_tensorboard(valid_result, epoch_idx, prefix='Valid')

                if update_flag:
                    if saved:
                        self._save_checkpoint(epoch_idx)
                        update_output = set_color('Saving current best', 'blue') + ': %s' % self.saved_model_file
                        if verbose:
                            self.logger.info(update_output)
                    self.best_valid_result = valid_result

                if callback_fn:
                    callback_fn(epoch_idx, valid_score)

                if stop_flag:
                    stop_output = 'Finished training, best eval result in epoch %d' % \
                                  (epoch_idx - self.cur_step * self.eval_step)
                    if verbose:
                        self.logger.info(stop_output)
                    break
        self._add_hparam_to_tensorboard(self.best_valid_score)
        return self.best_valid_score, self.best_valid_result
flowchart LR A[初始化与首轮验证] --> B[主训练循环] B --> C{是否到评估间隔?} C -- 否 --> B C -- 是 --> D[验证+显存回收] D --> E[早停/最优模型维护] E --> F[回调与日志扩展] F --> B E -->|stop| G[训练结束+超参记录]

1. 训练前的初始化与首轮验证:保障状态一致

进入fit后,优先做一些初始化工作:

  • 断点续训保障:首先检查是否需要保存模型,如果需要保存模型,则保存当前权重。若start_epoch已达epochs,立即保存当前权重,实现场景化“resume”。
  • 评估数据收集器eval_collector提前聚合训练数据,保障后续指标计算数据一致。
  • 首轮验证即基准:如果存在验证数据,则进行首轮验证,并输出日志和TensorBoard。训练前跑一次_valid_epoch,输出日志 + _log_metrics_to_tensorboard(valid_result, -1, prefix='Valid'),让后续曲线有参照;验证后若GPU可用则清理缓存,确保第一轮训练显存稳定。

2. 主训练循环:统一多损失与可视化

进入主训练循环后,开始训练模型,并记录训练损失、时间等指标,方便后续查看:

  • 训练耗时统计:训练开始和结束时记录时间,配合_generate_train_loss_output输出结构化日志。
  • 多损失归一train_loss可能是tuple(主损失 + 正则/对比项),因此在写入self.train_loss_dict、TensorBoard前统一求和,保证图表只有一条主曲线。
  • TensorBoard集成self.tensorboard.add_scalar('Train/Loss', total_loss, epoch_idx)提供最小工作量的可视化;若需要曲线对齐,可增加add_scalar('Train/Reg', ...)等子指标。
  • 训练损失记录:训练损失记录到self.train_loss_dict中,方便后续查看。

3. 动态评估、早停与最优模型存档

  • 跳过验证的快速模式eval_step <= 0或无验证集时直接_save_checkpointcontinue,适合只关心最终模型的debug场景。
  • 按步评估:达到eval_step倍数时执行验证_valid_epoch,结果通过_log_metrics_to_tensorboard(valid_result, epoch_idx, prefix='Valid')一次性写入,多指标训练也能保持结构一致。
  • 显存控制:验证后再次torch.cuda.empty_cache(),避免评估阶段的推理激增导致下一训练轮OOM。
  • 早停内核early_stopping(valid_score, best_valid_score, cur_step, ...)返回stop_flagupdate_flag,分别指示“是否该停”“是否刷新最好成绩”。通过stopping_step+eval_step组合,可以构造“x个评估周期内无提升即停”的策略。
    • early_stopping实现方法直观解释如下:如果当前指标value大于历史最佳best,则更新bestcur_step为当前值,并设置update_flag为True;否则,cur_step加1,如果cur_step大于max_step,则设置stop_flag为True。
  • 最优模型持久化:当update_flag为真时立即保存,并更新self.best_valid_result为最新指标快照,防止后续评估震荡覆盖佳绩。
  • 回调钩子if callback_fn: callback_fn(epoch_idx, valid_score)为外部调度系统提供切入点,可用于动态调参、Slack通知等。
  • 早停提示stop_flag触发后输出Finished training, best eval result...,追踪信息完整。
  • 超参记录:退出循环时调用_add_hparam_to_tensorboard(self.best_valid_score),将关键score等写入TensorBoard hparams表。若你的平台是MLflow/W&B,可在同位置统一上报。

5. _log_metrics_to_tensorboard实现细节

_log_metrics_to_tensorboard的实现细节如下:

  • 定义要记录的指标:首先定义要记录的指标metrics_to_log,然后遍历结果字典,记录所有匹配的指标。
  • 遍历结果字典:在遍历结果字典时,需要检查指标名称是否匹配,如果匹配,则用add_scalar记录到TensorBoard。
  • 记录验证分数:最后记录验证分数。