Topic: Customizing the Training Loop in Practice: A Case Study of CustomTrainer.fit
Date: 2025-11-27
Introduction
Most recommender-system trainers follow the usual "train - validate - save" routine, but in real experiments we often need to extend logging, manage GPU memory, add early stopping, or even hook in external callbacks.
While training a recommendation model of his own, Gaowei wrote a CustomTrainer.fit. This post takes its training-loop customization techniques as the core, abstracts them into general practices, and shares them here.
The full code is as follows:
# Assumed module-level imports (RecBole-style Trainer):
#   from time import time
#   import torch
#   from recbole.utils import set_color, dict2str, early_stopping
def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None):
    r"""Train the model based on the train data and the valid data.

    This method overrides the parent fit method to add TensorBoard logging for all evaluation metrics.

    Args:
        train_data (DataLoader): the train data
        valid_data (DataLoader, optional): the valid data, default: None
        verbose (bool, optional): whether to write training and evaluation information to logger, default: True
        saved (bool, optional): whether to save the model parameters, default: True
        show_progress (bool): show the progress of training epoch and evaluate epoch, default: False
        callback_fn (callable): optional callback executed at the end of each epoch,
            called with (epoch_idx, valid_score)

    Returns:
        (float, dict): best valid score and best valid result. If valid_data is None, it returns (-1, None)
    """
    # If training is already complete on resume, persist the current weights once
    if saved and self.start_epoch >= self.epochs:
        self._save_checkpoint(-1)

    self.eval_collector.data_collect(train_data)

    # Initial validation before training, logged as epoch -1
    if valid_data is not None:
        valid_start_time = time()
        valid_score, valid_result = self._valid_epoch(valid_data, show_progress=show_progress)
        valid_end_time = time()
        valid_score_output = (set_color("epoch %d evaluating", 'green') + " [" + set_color("time", 'blue')
                              + ": %.2fs, " + set_color("valid_score", 'blue') + ": %f]") % \
            (-1, valid_end_time - valid_start_time, valid_score)
        valid_result_output = set_color('valid result', 'blue') + ': \n' + dict2str(valid_result)
        self.logger.info(valid_score_output)
        self.logger.info(valid_result_output)
        # Log all metrics to TensorBoard
        self._log_metrics_to_tensorboard(valid_result, -1, prefix='Valid')
        # Clear GPU cache after evaluation to free memory for training
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    for epoch_idx in range(self.start_epoch, self.epochs):
        # train
        training_start_time = time()
        train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress)
        self.train_loss_dict[epoch_idx] = sum(train_loss) if isinstance(train_loss, tuple) else train_loss
        training_end_time = time()
        train_loss_output = \
            self._generate_train_loss_output(epoch_idx, training_start_time, training_end_time, train_loss)
        if verbose:
            self.logger.info(train_loss_output)
        # Log training loss to TensorBoard (tuple losses are summed into one scalar)
        if self.tensorboard is not None:
            if isinstance(train_loss, tuple):
                total_loss = sum(train_loss)
            else:
                total_loss = train_loss
            self.tensorboard.add_scalar('Train/Loss', total_loss, epoch_idx)

        # eval
        if self.eval_step <= 0 or not valid_data:
            if saved:
                self._save_checkpoint(epoch_idx)
                update_output = set_color('Saving current', 'blue') + ': %s' % self.saved_model_file
                if verbose:
                    self.logger.info(update_output)
            continue

        if (epoch_idx + 1) % self.eval_step == 0:
            valid_start_time = time()
            valid_score, valid_result = self._valid_epoch(valid_data, show_progress=show_progress)
            # Clear GPU cache after evaluation
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            self.best_valid_score, self.cur_step, stop_flag, update_flag = early_stopping(
                valid_score,
                self.best_valid_score,
                self.cur_step,
                max_step=self.stopping_step,
                bigger=self.valid_metric_bigger
            )
            valid_end_time = time()
            valid_score_output = (set_color("epoch %d evaluating", 'green') + " [" + set_color("time", 'blue')
                                  + ": %.2fs, " + set_color("valid_score", 'blue') + ": %f]") % \
                (epoch_idx, valid_end_time - valid_start_time, valid_score)
            valid_result_output = set_color('valid result', 'blue') + ': \n' + dict2str(valid_result)
            if verbose:
                self.logger.info(valid_score_output)
                self.logger.info(valid_result_output)
            # Log all metrics to TensorBoard (including the valid score)
            self._log_metrics_to_tensorboard(valid_result, epoch_idx, prefix='Valid')
            if update_flag:
                if saved:
                    self._save_checkpoint(epoch_idx)
                    update_output = set_color('Saving current best', 'blue') + ': %s' % self.saved_model_file
                    if verbose:
                        self.logger.info(update_output)
                self.best_valid_result = valid_result
            if callback_fn:
                callback_fn(epoch_idx, valid_score)
            if stop_flag:
                stop_output = 'Finished training, best eval result in epoch %d' % \
                    (epoch_idx - self.cur_step * self.eval_step)
                if verbose:
                    self.logger.info(stop_output)
                break

    self._add_hparam_to_tensorboard(self.best_valid_score)
    return self.best_valid_score, self.best_valid_result
flowchart LR
    A[Init & initial validation] --> B[Main training loop]
    B --> C{Eval interval reached?}
    C -- No --> B
    C -- Yes --> D[Validate + free GPU cache]
    D --> E[Early stopping / best-model tracking]
    E --> F[Callback & extended logging]
    F --> B
    E -->|stop| G[Training done + hparam logging]
1. Pre-training initialization and the initial validation: keeping state consistent
Once fit is entered, some initialization work comes first:
- Resume guarantee: check whether a checkpoint is needed; if start_epoch has already reached epochs, the current weights are saved immediately (as checkpoint -1), covering the "resume into an already-finished run" scenario.
- Evaluation data collector: eval_collector aggregates the training data up front so that later metric computation sees consistent data.
- Initial validation as a baseline: if validation data exists, run _valid_epoch once before training, log the results, and write them to TensorBoard via _log_metrics_to_tensorboard(valid_result, -1, prefix='Valid'), so later curves have a reference point. After validation, clear the GPU cache (if CUDA is available) so the first training epoch starts with stable memory.
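The resume guard at the top of fit can be isolated into a tiny helper for reuse and testing. A minimal sketch; maybe_save_on_resume and the save_fn parameter are hypothetical names, not part of the original code:

```python
def maybe_save_on_resume(start_epoch, epochs, saved, save_fn):
    """Persist the current weights once when fit() is entered after training
    has already finished (start_epoch >= epochs). Returns True if it saved."""
    if saved and start_epoch >= epochs:
        save_fn(-1)  # epoch index -1 marks a pre-loop checkpoint
        return True
    return False


# Usage: record which epoch indices were checkpointed
saved_epochs = []
maybe_save_on_resume(start_epoch=5, epochs=5, saved=True, save_fn=saved_epochs.append)
```

Keeping the guard as a pure function makes the "already finished" branch trivially unit-testable without a real trainer.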
2. The main training loop: unifying multi-part losses and visualization
Inside the main loop, the model is trained while loss and timing metrics are recorded for later inspection:
- Training-time accounting: timestamps at the start and end of training feed _generate_train_loss_output to produce structured logs.
- Multi-loss normalization: train_loss may be a tuple (main loss + regularization/contrastive terms), so it is summed before being written to self.train_loss_dict and TensorBoard, keeping the chart to a single main curve.
- TensorBoard integration: self.tensorboard.add_scalar('Train/Loss', total_loss, epoch_idx) gives minimal-effort visualization; if you need aligned curves, add sub-metrics such as add_scalar('Train/Reg', ...).
- Loss bookkeeping: training losses are recorded in self.train_loss_dict for later inspection.
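The loss-normalization and logging step can be sketched as a standalone function. Here writer stands in for a TensorBoard SummaryWriter (any object with an add_scalar method works), and log_train_loss is a hypothetical name:

```python
def log_train_loss(writer, train_loss, epoch_idx):
    """Sum tuple losses (main + regularization/contrastive terms) into a
    single scalar before logging, so the chart shows one main curve."""
    total_loss = sum(train_loss) if isinstance(train_loss, tuple) else train_loss
    if writer is not None:
        writer.add_scalar('Train/Loss', total_loss, epoch_idx)
    return total_loss


# Works for both a single loss and a tuple of loss terms:
assert log_train_loss(None, 1.0, 0) == 1.0
assert log_train_loss(None, (0.5, 0.25, 0.25), 1) == 1.0
```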
3. Dynamic evaluation, early stopping, and best-model archiving
- Fast mode that skips validation: when eval_step <= 0 or there is no validation set, the loop simply calls _save_checkpoint and continues, which suits debug runs that only care about the final model.
- Step-based evaluation: whenever the epoch index reaches a multiple of eval_step, _valid_epoch runs and the results are written in one call via _log_metrics_to_tensorboard(valid_result, epoch_idx, prefix='Valid'), so multi-metric training stays structurally consistent.
- GPU memory control: torch.cuda.empty_cache() is called again after validation, avoiding an OOM in the next training epoch caused by the memory spike of evaluation-time inference.
- Early-stopping core: early_stopping(valid_score, best_valid_score, cur_step, ...) returns stop_flag and update_flag, indicating "should training stop" and "was the best score beaten", respectively. Combining stopping_step with eval_step yields a "stop after x evaluation rounds without improvement" policy. Intuitively, early_stopping works like this: if the current value beats the historical best, update best to the current value, reset cur_step, and set update_flag to True; otherwise increment cur_step, and once cur_step exceeds max_step, set stop_flag to True.
- Best-model persistence: when update_flag is true, the checkpoint is saved immediately and self.best_valid_result is updated with the latest metric snapshot, so later evaluation fluctuations cannot overwrite the best result.
- Callback hook: if callback_fn: callback_fn(epoch_idx, valid_score) gives external schedulers an entry point, usable for dynamic hyperparameter tuning, Slack notifications, and so on.
- Early-stop message: once stop_flag fires, "Finished training, best eval result..." is logged, keeping the run trace complete.
- Hyperparameter logging: on leaving the loop, _add_hparam_to_tensorboard(self.best_valid_score) writes the key score into the TensorBoard hparams table; if your platform is MLflow/W&B, report at the same point.
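The patience logic described above can be written as a pure function. This is a sketch that mirrors the behavior described in the text (the semantics attributed to RecBole's early_stopping utility), not a verbatim copy of its source:

```python
def early_stopping(value, best, cur_step, max_step, bigger=True):
    """Return (best, cur_step, stop_flag, update_flag).

    On improvement (direction given by `bigger`) the best score is updated
    and the patience counter cur_step resets to 0; otherwise cur_step grows,
    and stop_flag turns True once it exceeds max_step.
    """
    improved = value > best if bigger else value < best
    if improved:
        return value, 0, False, True
    cur_step += 1
    return best, cur_step, cur_step > max_step, False


# "Stop after max_step evaluation rounds without improvement":
best, cur_step = 0.30, 0
for score in (0.35, 0.34, 0.33, 0.32):  # one improvement, then stagnation
    best, cur_step, stop_flag, update_flag = early_stopping(score, best, cur_step, max_step=2)
assert (best, stop_flag) == (0.35, True)
```

Because _valid_epoch only runs every eval_step epochs, a patience of max_step here really means max_step * eval_step training epochs without improvement.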
4. Implementation details of _log_metrics_to_tensorboard
_log_metrics_to_tensorboard is implemented in three steps:
- Define the metrics to log: first define metrics_to_log, then iterate over the result dict and record every matching metric.
- Iterate over the result dict: while iterating, check whether each metric name matches; if it does, record it to TensorBoard with add_scalar.
- Record the validation score: finally, record the validation score itself.
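Following the three steps above, a possible shape of the method looks like this. It is a sketch under assumptions: the prefixes in metrics_to_log, the optional valid_score argument, and the tag format are guesses, and writer stands in for self.tensorboard:

```python
def log_metrics_to_tensorboard(writer, valid_result, epoch_idx, prefix='Valid',
                               valid_score=None,
                               metrics_to_log=('recall', 'ndcg', 'hit', 'mrr', 'precision')):
    """Write each metric in valid_result whose name starts with one of the
    configured prefixes, e.g. 'recall@10' -> 'Valid/recall@10', then log the
    aggregate validation score if one is supplied."""
    if writer is None:
        return
    for name, value in valid_result.items():
        if any(name.lower().startswith(m) for m in metrics_to_log):
            writer.add_scalar('%s/%s' % (prefix, name), value, epoch_idx)
    if valid_score is not None:
        writer.add_scalar('%s/valid_score' % prefix, valid_score, epoch_idx)
```

Keeping the prefix list as a parameter lets the same helper serve both the epoch -1 baseline run and the per-eval_step runs without code changes.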