hf trainner
Hugging Face Trainer 类核心训练流程(完整梳理版)
你需要一份对 Hugging Face Trainer 类核心训练流程的完整梳理,保留所有核心细节且逻辑连贯,下面将从整体架构、三大核心方法、关键核心概念三个维度进行全面整合梳理:
一、整体架构:Trainer 训练的「三段式」全局流程
Trainer 类的所有训练逻辑均以 train() 方法为统一入口,整体遵循「初始化准备 → 核心训练循环 → 收尾保存」的三大阶段,各阶段层层递进、职责明确,全局调用链路和流程框架如下:
train() 入口方法
├── 阶段1:初始化与准备工作(训练前)
│ ├── 参数与环境校验
│ ├── 模型与优化器/调度器初始化
│ ├── 数据加载与预处理
│ └── 断点续训:Checkpoint 加载(可选)
├── 阶段2:核心训练循环(训练中,_inner_training_loop() 承载)
│ ├── 训练前二次初始化与配置
│ ├── 三层嵌套循环(Epoch → 梯度累积块 → 单Batch)
│ │ ├── 单Batch训练(training_step() 承载)
│ │ │ └── 损失计算(compute_loss() 承载)
│ │ ├── 梯度累积完成后:梯度裁剪→参数更新→学习率调度→梯度清零
│ │ └── 每步/每轮:日志记录→评估→保存Checkpoint
│ └── 训练循环终止判断
└── 阶段3:收尾与保存(训练后)
├── 最佳模型加载(可选)
├── 训练结果整理与指标汇总
├── 资源清理与Checkpoint瘦身
├── 模型推送至Hub(可选)
└── 返回 TrainOutput 训练结果
二、阶段1:初始化与准备工作(train() 方法前期)
该阶段的核心目标是为后续训练循环搭建好基础环境,完成各类配置、资源的校验与初始化,避免训练过程中出现参数不匹配、资源缺失等问题,核心操作包括:
-
参数与环境校验
- 校验
TrainingArguments关键参数一致性(如eval_strategy与eval_dataset匹配、save_strategy与metric_for_best_model匹配)。 - 初始化分布式训练环境(Accelerator)、混合精度训练(fp16/bf16)、日志系统与回调管理器(
CallbackHandler)。 - 检查模型合法性(是否为可训练的
AutoModelForXxx类),处理量化模型、PEFT 适配器模型的特殊逻辑。
- 校验
-
模型与优化器/调度器初始化
- 若传入
model_init,调用该方法实例化模型(支持超参搜索时重新初始化)。 - 调用
create_optimizer_and_scheduler()完成双层初始化:create_optimizer():划分带权重衰减/不带权重衰减的参数组,支持 AdamW、Adafactor、BitsAndBytes 8bit 优化器等。create_scheduler():基于num_training_steps初始化调度器,支持linear/cosine等多种调度策略。
- 可选激活 NEFTune:通过
_activate_neftune()为嵌入层注册前向钩子,注入噪声提升模型泛化性。
- 若传入
-
数据加载与预处理
- 调用
get_train_dataloader()构建训练数据加载器:- 自动移除模型
forward()不接受的列(_remove_unused_columns()),避免前向传播报错。 - 构建采样器(默认
RandomSampler,支持LengthGroupedSampler按序列长度分组采样,提升训练效率)。 - 封装数据整理器(
DataCollator,默认default_data_collator,支持动态 padding,减少无效计算)。
- 自动移除模型
- 计算总批次大小:
total_train_batch_size = 单设备批次 × 梯度累积步数 × 分布式进程数,作为后续训练步数计算的基础。
- 调用
-
断点续训:Checkpoint 加载(可选)
- 若指定
resume_from_checkpoint,调用_load_from_checkpoint()完整加载训练状态:- 模型权重、优化器状态、学习率调度器状态。
- RNG 随机状态(保证续跑训练的可复现性)。
- 校验检查点与当前训练参数的一致性(如批次大小、日志步数),避免训练逻辑冲突。
- 若指定
三、阶段2:核心训练循环(_inner_training_loop() 承载)
该阶段是 Trainer 训练的核心执行环节,承接阶段1的初始化准备,完成「多轮 epoch 迭代 → 梯度累积 → 单批次训练 → 参数更新 → 日志/评估/保存」的完整闭环,最终触发训练终止条件。
子阶段2.1:训练循环前置准备(二次初始化与配置)
在进入 epoch 循环前,完成分布式适配、训练状态校准、断点续训数据跳过等关键准备,确保训练启动的稳定性。
- 显存预热与批量大小适配
- 调用
self.accelerator.free_memory()释放加速器闲置显存,避免初始化阶段显存堆积。 - 若启用
auto_find_batch_size,自动调整批量大小并更新 DeepSpeed 配置,释放冗余内存。
- 调用
- 数据加载器与训练参数校准
- 获取训练数据加载器
train_dataloader,对 FSDP-XLA v2 等分布式模式做 TPU 适配处理。 - 调用
set_initial_training_values()计算核心控制参数:- 总训练 epoch 数
num_train_epochs、每 epoch 优化步数num_update_steps_per_epoch。 - 总优化步数
max_steps(优先级高于 epoch 数,达到后直接终止训练)。 - 判定训练模式:
epoch_based(按 epoch 训练)或step_based(按步数训练)。
- 总训练 epoch 数
- 若启用
include_tokens_per_second,统计总训练令牌数num_train_tokens,用于后续速度评估。
- 获取训练数据加载器
- 调试模式与分布式兼容性检查
- 若启用
DebugOption.UNDERFLOW_OVERFLOW下溢/上溢调试,校验分布式模式(仅支持 DDP,不支持 DP),初始化调试工具DebugUnderflowOverflow。 - 判定优化器延迟创建逻辑:SageMaker MP、FSDP/XLA 等模式需延迟创建,FSDP2 模式禁用延迟创建。
- 若启用
- 优化器与调度器最终初始化
- 若启用 DeepSpeed,通过
deepspeed_init()完成优化器、调度器的分布式初始化。 - 若非延迟创建模式,调用
create_optimizer_and_scheduler()基于max_steps创建优化器(默认 AdamW)和学习率调度器(默认线性暖场)。
- 若启用 DeepSpeed,通过
- 模型包装与分布式适配
- 初始化
TrainerState训练状态对象,记录全局步数、最佳指标、回调状态等,计算日志/评估/保存的绝对触发步数。 - 若启用
gradient_checkpointing,开启模型梯度检查点功能(牺牲少量速度,大幅节省显存)。 - 调用
_wrap_model()包装模型,适配 DDP/FSDP/DeepSpeed 等分布式模式,通过accelerator.prepare()完成模型、优化器、调度器的设备迁移与并行适配。
- 初始化
- 断点续训:训练状态完整恢复
- 若指定
resume_from_checkpoint,根据分布式模式加载模型权重、优化器状态、调度器状态、梯度缩放器状态。 - 加载
TrainerState,计算已训练 epoch 数epochs_trained和当前 epoch 已训练步数,生成数据跳过偏移量,避免重复训练已处理数据。
- 若指定
- 训练配置日志打印
输出核心训练信息,方便用户校验配置:- 数据维度:总样本数、每 epoch 样本数、总令牌数(若统计)。
- 训练维度:单设备批次大小、总批次大小(含并行/梯度累积)、梯度累积步数、总优化步数。
- 模型维度:可训练参数数量、分布式模式、混合精度类型(fp16/bf16)。
子阶段2.2:三层嵌套训练循环(核心执行逻辑)
采用 「Epoch 外层循环 → 梯度累积块中层循环 → 单 Batch 内层循环」 的三层结构,核心是实现梯度累积,平衡显存占用与训练效果。
# 外层:Epoch 循环(遍历训练轮数)
for epoch in range(epochs_trained, num_train_epochs):
# Epoch 前置准备
self._epoch_setup(epoch) # 设置 DataLoader epoch、重置 K/V 缓存等
epoch_iterator = self.get_train_dataloader()
total_updates = self._get_num_updates_per_epoch(epoch) # 本轮 epoch 需完成的更新步数
tr_loss_epoch = 0.0 # 本轮 epoch 损失累积
# 中层:梯度累积块循环(按梯度累积步数分块)
for update_step in range(total_updates):
# 获取当前累积块的所有 Batch 数据(数量 = gradient_accumulation_steps)
batch_samples, num_items_in_batch = self.get_batch_samples(
epoch_iterator, args.gradient_accumulation_steps, args.device
)
current_grad_accum_steps = len(batch_samples) # 适配最后一块的剩余 Batch
tr_loss_block = 0.0 # 本累积块损失累积
# 内层:单 Batch 循环(遍历累积块中的每个 Batch)
for i, inputs in enumerate(batch_samples):
# -------------------------- 单 Batch 训练核心 --------------------------
# 1. 令牌数统计(若启用)
if args.include_tokens_per_second:
num_tokens = self._get_num_tokens(inputs)
self.state.total_flos += self._compute_flos(inputs, num_tokens)
self.state.num_tokens_seen += num_tokens
# 2. 梯度同步上下文管理(非最后一个 Batch 禁用梯度同步,提升效率)
sync_grad = (i == current_grad_accum_steps - 1) or (self.state.global_step + 1 >= args.max_steps)
with self.accelerator.no_sync(model) if not sync_grad else nullcontext():
# 3. 调用 training_step 完成单 Batch 训练:前向传播→损失计算→反向传播
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
tr_loss_block += tr_loss_step
tr_loss_epoch += tr_loss_step
# -------------------------- 梯度累积完成:参数更新 --------------------------
if sync_grad:
# 梯度裁剪:防止梯度爆炸(适配 Apex/DeepSpeed 等特殊模式)
if args.max_grad_norm > 0:
self.accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
# 触发优化器步骤前回调
self.callback_handler.on_pre_optimizer_step(self.args, self.state, self.control)
# 优化器更新参数 + 学习率调度
self.optimizer.step()
if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
self.lr_scheduler.step()
# 触发优化器步骤后回调
self.callback_handler.on_optimizer_step(self.args, self.state, self.control)
# 梯度清零,准备下一个累积块
model.zero_grad()
self.state.global_step += 1 # 全局步数+1
# -------------------------- 日志/评估/保存 --------------------------
self._maybe_log_save_evaluate(
model, tr_loss_block / current_grad_accum_steps, epoch, update_step
)
# -------------------------- 训练终止判断 --------------------------
if self.state.global_step >= args.max_steps or self.control.should_training_stop:
break
if self.state.global_step >= args.max_steps or self.control.should_training_stop:
break
# Epoch 收尾:触发 epoch 结束回调、评估、保存
self._epoch_end_callback(epoch, tr_loss_epoch)
if self.state.global_step >= args.max_steps or self.control.should_training_stop:
break
三层循环的核心细节
-
Epoch 外层循环
- 每轮 epoch 开始前调用
_epoch_setup():设置 DataLoader 的 epoch 种子(保证 shuffle 一致性)、重置 Transformer K/V 缓存等。 - 触发
on_epoch_begin回调,支持用户自定义 epoch 前置操作(如学习率调整、数据增强策略切换)。 - 每轮 epoch 结束后调用
_epoch_end_callback():触发on_epoch_end回调、执行评估、保存 checkpoint。
- 每轮 epoch 开始前调用
-
梯度累积块中层循环
- 按
gradient_accumulation_steps将 epoch 数据划分为多个累积块,每个块包含 N 个 Batch。 - 核心目的:累积 N 个 Batch 的梯度后再执行一次参数更新,等价于「批量大小 × N」的训练效果,但显存占用仅为「单 Batch + 梯度缓存」。
- 处理最后一个累积块的边界情况:若剩余 Batch 数不足 N,自动调整
current_grad_accum_steps,避免丢弃数据。
- 按
-
单 Batch 内层循环
- 遍历累积块中的每个 Batch,调用
training_step()完成单批次训练(前向传播→损失计算→反向传播)。 - 梯度同步控制:非累积块最后一个 Batch 时,通过
accelerator.no_sync()禁用梯度同步,减少通信开销。 - 损失累积:将每个 Batch 的损失累加到块损失
tr_loss_block和 epoch 损失tr_loss_epoch。
- 遍历累积块中的每个 Batch,调用
关键辅助操作:_maybe_log_save_evaluate()
每次参数更新后触发,是训练过程中「可视化、验证、固化成果」的核心入口:
- 日志记录:打印训练损失、学习率、梯度范数、显存占用等指标,支持 TensorBoard/WandB 等可视化工具。
- 模型评估:若达到
eval_steps或 epoch 结束,调用_evaluate()运行验证集,通过compute_metrics()计算任务指标(如准确率、BLEU 分数),更新最佳模型指标。 - Checkpoint 保存:若达到
save_steps或触发最佳指标,调用_save_checkpoint()保存模型权重、优化器状态、训练状态,支持save_total_limit清理过期 checkpoint。 - 超参搜索汇报:若启用超参搜索(Optuna/Ray),汇报当前试验结果,支持早停(Prune)效果差的试验。
子阶段2.3:训练循环终止与临时收尾
当满足以下任一条件时,立即终止训练循环:
- 全局步数
state.global_step >= max_steps(优先终止条件)。 - 回调控制
control.should_training_stop = True(用户自定义早停)。 - 所有 epoch 遍历完成。
终止后执行临时收尾:
- 清理 Transformer K/V 缓存等临时状态,释放内存。
- 统计训练过程中的总损失、总浮点运算量
total_flos等指标。
四、阶段3:收尾与保存(训练后)
该阶段的核心目标是完成训练后的资源整理、结果固化,确保训练成果可复用、可追溯,核心操作包括:
-
最佳模型加载(可选)
- 若
load_best_model_at_end=True,调用_load_best_model()加载训练过程中基于metric_for_best_model判定的最佳模型 Checkpoint。
- 若
-
训练结果整理与指标汇总
- 计算平均训练损失:
平均损失 = _total_loss_scalar / effective_global_step(做保底处理,避免除零错误)。 - 收集多维度训练指标:速度指标(每秒样本数、每秒令牌数)、浮点运算量(
total_flos)、内存使用情况、可训练参数数量等。
- 计算平均训练损失:
-
资源清理与 Checkpoint 瘦身
- 停用 NEFTune(若启用):调用
_deactivate_neftune()移除嵌入层前向钩子,释放相关资源。 - 根据
save_total_limit删除过期 Checkpoint,仅保留最新/最佳的若干个,节省存储空间。 - 清理过往状态缓存(
_past,如 Transformer K/V 缓存)、释放加速器闲置显存。
- 停用 NEFTune(若启用):调用
-
模型推送与结果返回
- 若
push_to_hub=True,调用push_to_hub()上传模型、配置文件、训练参数到 Hugging Face Hub,实现模型共享。 - 封装
TrainOutput对象,返回全局步数、平均训练损失、训练指标字典,结束整个训练流程。
- 若
五、三大核心方法:Trainer 训练的「核心引擎」
_inner_training_loop()、training_step()、compute_loss() 是 Trainer 类的三大核心方法,构成了训练流程的「三级调用链路」,也是整个训练逻辑的核心承载者,三者分工明确、层层调用。
方法1:_inner_training_loop() - 核心训练主循环(统筹全局)
该方法是「端到端训练」的核心统筹者,承接训练前的准备工作,向下调用 training_step() 完成单批次训练,最终完成所有 epoch/step 的训练、结果汇总和清理,核心职责是搭建训练框架、处理分布式适配、实现梯度累积与断点续训。
| 核心职责 | 关键操作 |
|---|---|
| 训练前配置 | 显存预热、参数校准、模型包装、断点续训状态恢复 |
| 训练循环控制 | 三层嵌套循环(Epoch→梯度累积块→单Batch)、终止条件判断 |
| 全局状态管理 | 训练状态记录、日志/评估/保存触发、回调事件调度 |
| 分布式适配 | DDP/FSDP/DeepSpeed 模式支持、多设备梯度同步与损失聚合 |
方法2:training_step() - 单批次训练(执行具体训练步骤)
该方法是「单批次训练」的执行者,负责完成单批次数据的「前向传播 → 损失计算 → 反向传播准备」的闭环,是梯度更新的基础,被 _inner_training_loop() 循环调用。
核心执行流程
-
前置准备工作
- 上下文并行准备:调用
_prepare_context_parallel_inputs生成分布式上下文并行环境(cp_context)。 - 训练模式激活:调用
model.train()切换模型为训练模式(启用 Dropout、BatchNorm 等训练专属层行为),同步激活优化器训练模式(若支持)。 - 输入数据预处理:调用
_prepare_inputs完成设备迁移、数据格式转换,确保数据符合模型输入要求。
- 上下文并行准备:调用
-
核心步骤:损失计算与反向传播准备
- SageMaker 并行适配:若启用 SageMaker 模型并行,通过
smp_forward_backward完成前向与反向传播,返回聚合损失。 - 常规损失计算:通过
self.compute_loss()调用compute_loss()方法,完成模型前向传播与批次损失计算(借助compute_loss_context_manager()提供梯度累积上下文支持)。 - 显存优化:删除临时输入数据
inputs,若配置torch_empty_cache_steps,按间隔清理对应硬件设备的闲置显存(CUDA/XPU/TPU 等),避免 OOM。
- SageMaker 并行适配:若启用 SageMaker 模型并行,通过
-
多设备与混合精度适配
- 多 GPU 损失聚合:若
n_gpu > 1,通过loss.mean()对各 GPU 损失求平均,保证损失一致性。 - 混合精度反向传播:
- 启用 Apex 混合精度:通过
amp.scale_loss缩放损失,避免梯度下溢后执行反向传播。 - 启用 Accelerator 混合精度:通过
self.accelerator.backward()完成反向传播初始化,适配 DeepSpeed 分布式模式。
- 多 GPU 损失聚合:若
-
损失返回:通过
loss.detach()使损失张量脱离计算图(不参与后续不必要的梯度传播),返回该批次的损失值。
核心关键点
- 职责明确:仅负责单批次训练的完整步骤,不参与全局循环与梯度累积判断,保证逻辑简洁。
- 显存优化优先:包含多处显存清理逻辑,平衡训练效率与显存占用。
- 计算图管理:返回脱离计算图的损失值,节省计算资源,避免梯度泄露。
方法3:compute_loss() - 损失计算(专用工具方法)
该方法是「损失计算」的专用工具,负责处理模型输入、调用前向传播、支持自定义损失与标签平滑,最终返回批次损失,被 training_step() 调用。
核心执行流程
-
标签数据预处理
- 若配置标签平滑(
label_smoother)或自定义损失函数(compute_loss_func),且输入包含labels,则弹出labels单独保存(避免重复传入模型)。 - 否则,将
labels设为None,不单独处理。
- 若配置标签平滑(
-
模型前向传播
- 若模型支持损失相关关键字参数,将
num_items_in_batch(批次样本数)注入输入字典。 - 调用
model(**inputs)执行前向传播,得到模型输出outputs,可选保存过往状态(如 Transformer K/V 缓存)用于后续优化。
- 若模型支持损失相关关键字参数,将
-
多场景损失计算(优先级:自定义损失 > 标签平滑损失 > 模型自带损失)
- 场景 1(自定义损失):若
compute_loss_func已定义,调用该函数传入模型输出、标签、批次样本数,返回自定义损失。 - 场景 2(标签平滑损失):若无自定义损失但
labels非空,通过label_smoother计算损失(因果语言模型会执行标签移位shift_labels=True,避免看到未来标签)。 - 场景 3(模型自带损失):若既无自定义损失也无单独标签,从模型输出中提取损失(字典输出取
outputs["loss"],元组输出取outputs[0]),无法提取则抛出异常。
- 场景 1(自定义损失):若
-
多设备适配与结果返回
- 若配置跨设备令牌平均,将损失值乘以进程数/GPU 数,保证多设备损失计算准确性。
- 根据
return_outputs参数,返回「仅损失」或「(损失,模型输出)」。
核心关键点
- 专一性:仅负责损失计算,不参与反向传播与参数更新,便于用户自定义修改。
- 灵活性:支持三种损失计算场景,满足不同任务(分类、生成、微调等)的需求。
- 鲁棒性:充分考虑多设备、不同模型输出格式的适配,减少训练报错概率。
三大方法的调用关系与分工总结
| 方法名 | 核心定位 | 被调用者 | 调用者 | 核心输出 |
|---|---|---|---|---|
_inner_training_loop() |
训练主循环(统筹全局) | train() |
training_step() |
训练状态、总损失、各类指标 |
training_step() |
单批次训练(执行具体步骤) | _inner_training_loop() |
compute_loss() |
单批次损失值 |
compute_loss() |
损失计算(专用工具) | training_step() |
模型 forward() 方法 |
批次损失(可选模型输出) |
六、关键核心概念:Trainer 训练的「底层支撑」
-
梯度累积
- 核心逻辑:在
_inner_training_loop()的三层嵌套循环中,通过收集gradient_accumulation_steps个 Batch 的梯度,仅在最后一个 Batch 执行optimizer.step(),实现「小批次显存占用,大批次训练效果」。 - 核心价值:解决显存不足无法设置大批次的问题,平衡显存占用与模型训练稳定性(大批次通常能带来更好的收敛效果)。
- 核心逻辑:在
-
分布式训练
- 核心支撑:基于
Accelerator封装,无需用户手动编写分布式逻辑。 - 支持模式:分布式数据并行(DDP)、完全分片数据并行(FSDP)、DeepSpeed、SageMaker 模型并行、TPU/XLA 并行等。
- 核心适配:自动处理模型包装、梯度同步、数据分发、损失聚合,保证多设备/多进程训练的一致性。
- 核心支撑:基于
-
回调系统
- 核心载体:
CallbackHandler管理各类回调函数。 - 关键事件:
on_train_begin(训练开始)、on_epoch_begin/end(epoch 开始/结束)、on_step_end(步骤结束)、on_train_end(训练结束)。 - 核心价值:支持用户自定义扩展训练逻辑(如早停、自定义日志推送、额外评估),无需修改
Trainer核心源码。
- 核心载体:
-
Checkpoint 机制
- 保存内容:模型权重、优化器状态、学习率调度器状态、
TrainerState(训练状态)、RNG 随机状态。 - 核心方法:
_save_checkpoint()(保存)、_load_from_checkpoint()(加载)。 - 核心价值:实现断点续训(避免训练中断后重新开始)、保存最佳模型、后续模型微调与部署的基础。
- 保存内容:模型权重、优化器状态、学习率调度器状态、
-
混合精度训练
- 核心模式:fp16(半精度)、bf16(脑浮点数)。
- 核心优化:通过损失缩放(避免梯度下溢)、自动精度转换,在保证模型收敛效果的前提下,大幅节省显存占用、提升训练速度。
- 适配方式:基于
Accelerator或 Apex 封装,用户仅需配置fp16=True或bf16=True即可启用。
七、整体总结
- 流程规范:
Trainer训练流程遵循「初始化 → 核心循环 → 收尾」的三段式架构,逻辑清晰、层层递进,保证了训练的规范性和可复现性。 - 分工明确:三大核心方法构成「三级调用链路」,各司其职,既保证了核心逻辑的封装性,又为用户自定义提供了灵活入口。
- 功能完备:内置梯度累积、分布式训练、断点续训、混合精度等核心功能,无需用户关注底层实现细节,降低上手门槛。
- 扩展性强:通过自定义
compute_loss()、重写training_step()、自定义回调函数,可满足各类个性化训练需求(如特殊任务损失、自定义优化逻辑)。
Trainer 类的核心价值在于封装了复杂的底层训练逻辑,将用户从繁琐的分布式适配、梯度管理、显存优化中解放出来,专注于模型与任务本身,是 Hugging Face Transformers 库中自然语言处理、计算机视觉等任务训练的核心基石。
Hugging Face Trainer 核心流程关键节点对照表
这份对照表整合了训练全流程的阶段划分、核心操作、输入输出、关键方法、注意事项,方便你快速查阅和定位核心逻辑。
| 层级 | 名称 | 核心操作 | 输入 | 输出 | 关键调用方法 | 注意事项 |
|---|---|---|---|---|---|---|
| 全局入口 | train() 方法 |
串联三大阶段,完成训练启动与结果返回 | TrainingArguments、模型、数据集 |
TrainOutput(步数、损失、指标) |
_inner_training_loop() |
是所有训练逻辑的统一入口,前置完成初始化校验 |
| 阶段1 | 初始化与准备工作 | 1. 参数/环境校验 2. 模型/优化器/调度器初始化 3. 数据加载预处理 4. 断点续训加载 |
配置参数、空白模型、原始数据集 | 校验后的训练环境、初始化的模型/优化器、预处理数据加载器、恢复的训练状态 | create_optimizer_and_scheduler()get_train_dataloader()_load_from_checkpoint() |
1. 自动过滤模型不接受的输入列 2. 总批次大小 = 单设备批次 × 梯度累积 × 进程数 3. 断点续训需校验参数一致性 |
| 阶段2.1 | 训练循环前置准备 | 1. 显存预热与批量适配 2. 核心训练参数计算 3. 分布式模型包装 4. 训练状态校准 |
阶段1输出的环境/数据/模型 | 适配分布式的模型/优化器、校准后的训练步数(max_steps)、跳过已训练数据的偏移量 |
set_initial_training_values()_wrap_model()accelerator.prepare() |
1. max_steps 优先级高于 num_train_epochs2. 梯度检查点会牺牲速度换显存 3. 不同分布式模式(DDP/FSDP)的模型包装逻辑不同 |
| 阶段2.2 | Epoch 外层循环 | 1. 每轮 epoch 初始化(种子/缓存) 2. 触发 epoch 前后回调 3. 控制 epoch 迭代终止 |
校准后的训练状态、数据加载器 | 每轮 epoch 的损失累积值、更新的训练状态 | _epoch_setup()_epoch_end_callback() |
1. 设置 DataLoader epoch 保证 shuffle 一致性 2. 支持用户自定义 epoch 级操作(回调) |
| 阶段2.2 | 梯度累积块中层循环 | 1. 按累积步数划分 Batch 块 2. 处理最后一块的边界 Batch |
单 epoch 数据迭代器 | 累积块的 Batch 列表、块内样本数 | get_batch_samples() |
核心目的:小显存实现大批次效果,最后一块自动适配剩余 Batch 数 |
| 阶段2.2 | 单 Batch 内层循环 | 1. 单批次训练(前向/损失/反向) 2. 梯度同步控制 3. 损失累积 |
单个 Batch 数据、包装后的模型 | 单 Batch 损失值、累积的梯度 | training_step()accelerator.no_sync() |
非最后一个 Batch 禁用梯度同步,减少通信开销 |
| 阶段2.2 | 参数更新步骤 | 1. 梯度裁剪 2. 优化器更新 3. 学习率调度 4. 梯度清零 |
累积的梯度、优化器/调度器 | 更新后的模型参数、更新的学习率、+1 的全局步数 | accelerator.clip_grad_norm_()optimizer.step()lr_scheduler.step() |
1. 梯度裁剪防止梯度爆炸 2. 仅在累积块完成后执行一次更新 |
| 阶段2.2 | 日志/评估/保存 | 1. 训练指标记录 2. 验证集评估 3. Checkpoint 保存 4. 超参搜索汇报 |
训练状态、模型、验证集 | 日志指标、最佳模型指标、保存的 Checkpoint | _maybe_log_save_evaluate()_evaluate()_save_checkpoint() |
1. 按 eval_steps/save_steps 触发2. save_total_limit 自动清理过期 Checkpoint |
| 阶段2.3 | 训练循环终止 | 1. 判断终止条件 2. 临时资源清理 |
全局步数、回调控制状态 | 终止信号、清理后的临时内存 | - | 终止条件:达到 max_steps / 回调早停 / epoch 遍历完成 |
| 阶段3 | 收尾与保存 | 1. 加载最佳模型(可选) 2. 训练指标汇总 3. 资源清理与 Checkpoint 瘦身 4. 模型推送到 Hub |
训练状态、保存的 Checkpoint | 最佳模型权重、训练指标字典、清理后的存储空间 | _load_best_model()speed_metrics()push_to_hub() |
1. 最佳模型基于 metric_for_best_model 判定2. 停用 NEFTune 钩子,释放显存 |
| 核心方法 | training_step() |
单 Batch 完整训练步骤:前向→损失→反向→显存优化 | 模型、单 Batch 数据、样本数 | 脱离计算图的单 Batch 损失值 | compute_loss()accelerator.backward() |
1. 自动适配 SageMaker 并行 2. 定期清理显存避免 OOM |
| 核心方法 | compute_loss() |
多场景损失计算:自定义损失→标签平滑→模型自带损失 | 模型、Batch 数据、标签(可选) | 批次损失值(可选模型输出) | model(**inputs)label_smoother() |
1. 因果语言模型自动执行标签移位 2. 支持用户自定义损失函数 |

浙公网安备 33010602011771号