hf trainner

Hugging Face Trainer 类核心训练流程(完整梳理版)

你需要一份对 Hugging Face Trainer 类核心训练流程的完整梳理,保留所有核心细节且逻辑连贯,下面将从整体架构三大核心方法关键核心概念三个维度进行全面整合梳理:

一、整体架构:Trainer 训练的「三段式」全局流程

Trainer 类的所有训练逻辑均以 train() 方法为统一入口,整体遵循「初始化准备 → 核心训练循环 → 收尾保存」的三大阶段,各阶段层层递进、职责明确,全局调用链路和流程框架如下:

train() 入口方法
├── 阶段1:初始化与准备工作(训练前)
│   ├── 参数与环境校验
│   ├── 模型与优化器/调度器初始化
│   ├── 数据加载与预处理
│   └── 断点续训:Checkpoint 加载(可选)
├── 阶段2:核心训练循环(训练中,_inner_training_loop() 承载)
│   ├── 训练前二次初始化与配置
│   ├── 三层嵌套循环(Epoch → 梯度累积块 → 单Batch)
│   │   ├── 单Batch训练(training_step() 承载)
│   │   │   └── 损失计算(compute_loss() 承载)
│   │   ├── 梯度累积完成后:梯度裁剪→参数更新→学习率调度→梯度清零
│   │   └── 每步/每轮:日志记录→评估→保存Checkpoint
│   └── 训练循环终止判断
└── 阶段3:收尾与保存(训练后)
    ├── 最佳模型加载(可选)
    ├── 训练结果整理与指标汇总
    ├── 资源清理与Checkpoint瘦身
    ├── 模型推送至Hub(可选)
    └── 返回 TrainOutput 训练结果

二、阶段1:初始化与准备工作(train() 方法前期)

该阶段的核心目标是为后续训练循环搭建好基础环境,完成各类配置、资源的校验与初始化,避免训练过程中出现参数不匹配、资源缺失等问题,核心操作包括:

  1. 参数与环境校验

    • 校验 TrainingArguments 关键参数一致性(如 eval_strategyeval_dataset 匹配、save_strategymetric_for_best_model 匹配)。
    • 初始化分布式训练环境(Accelerator)、混合精度训练(fp16/bf16)、日志系统与回调管理器(CallbackHandler)。
    • 检查模型合法性(是否为可训练的 AutoModelForXxx 类),处理量化模型、PEFT 适配器模型的特殊逻辑。
  2. 模型与优化器/调度器初始化

    • 若传入 model_init,调用该方法实例化模型(支持超参搜索时重新初始化)。
    • 调用 create_optimizer_and_scheduler() 完成双层初始化:
      • create_optimizer():划分带权重衰减/不带权重衰减的参数组,支持 AdamW、Adafactor、BitsAndBytes 8bit 优化器等。
      • create_scheduler():基于 num_training_steps 初始化调度器,支持 linear/cosine 等多种调度策略。
    • 可选激活 NEFTune:通过 _activate_neftune() 为嵌入层注册前向钩子,注入噪声提升模型泛化性。
  3. 数据加载与预处理

    • 调用 get_train_dataloader() 构建训练数据加载器:
      • 自动移除模型 forward() 不接受的列(_remove_unused_columns()),避免前向传播报错。
      • 构建采样器(默认 RandomSampler,支持 LengthGroupedSampler 按序列长度分组采样,提升训练效率)。
      • 封装数据整理器(DataCollator,默认 default_data_collator,支持动态 padding,减少无效计算)。
    • 计算总批次大小:total_train_batch_size = 单设备批次 × 梯度累积步数 × 分布式进程数,作为后续训练步数计算的基础。
  4. 断点续训:Checkpoint 加载(可选)

    • 若指定 resume_from_checkpoint,调用 _load_from_checkpoint() 完整加载训练状态:
      • 模型权重、优化器状态、学习率调度器状态。
      • RNG 随机状态(保证续跑训练的可复现性)。
    • 校验检查点与当前训练参数的一致性(如批次大小、日志步数),避免训练逻辑冲突。

三、阶段2:核心训练循环(_inner_training_loop() 承载)

该阶段是 Trainer 训练的核心执行环节,承接阶段1的初始化准备,完成「多轮 epoch 迭代 → 梯度累积 → 单批次训练 → 参数更新 → 日志/评估/保存」的完整闭环,最终触发训练终止条件。

子阶段2.1:训练循环前置准备(二次初始化与配置)

在进入 epoch 循环前,完成分布式适配、训练状态校准、断点续训数据跳过等关键准备,确保训练启动的稳定性。

  1. 显存预热与批量大小适配
    • 调用 self.accelerator.free_memory() 释放加速器闲置显存,避免初始化阶段显存堆积。
    • 若启用 auto_find_batch_size,自动调整批量大小并更新 DeepSpeed 配置,释放冗余内存。
  2. 数据加载器与训练参数校准
    • 获取训练数据加载器 train_dataloader,对 FSDP-XLA v2 等分布式模式做 TPU 适配处理。
    • 调用 set_initial_training_values() 计算核心控制参数:
      • 总训练 epoch 数 num_train_epochs、每 epoch 优化步数 num_update_steps_per_epoch
      • 总优化步数 max_steps(优先级高于 epoch 数,达到后直接终止训练)。
      • 判定训练模式:epoch_based(按 epoch 训练)或 step_based(按步数训练)。
    • 若启用 include_tokens_per_second,统计总训练令牌数 num_train_tokens,用于后续速度评估。
  3. 调试模式与分布式兼容性检查
    • 若启用 DebugOption.UNDERFLOW_OVERFLOW 下溢/上溢调试,校验分布式模式(仅支持 DDP,不支持 DP),初始化调试工具 DebugUnderflowOverflow
    • 判定优化器延迟创建逻辑:SageMaker MP、FSDP/XLA 等模式需延迟创建,FSDP2 模式禁用延迟创建。
  4. 优化器与调度器最终初始化
    • 若启用 DeepSpeed,通过 deepspeed_init() 完成优化器、调度器的分布式初始化。
    • 若非延迟创建模式,调用 create_optimizer_and_scheduler() 基于 max_steps 创建优化器(默认 AdamW)和学习率调度器(默认线性暖场)。
  5. 模型包装与分布式适配
    • 初始化 TrainerState 训练状态对象,记录全局步数、最佳指标、回调状态等,计算日志/评估/保存的绝对触发步数。
    • 若启用 gradient_checkpointing,开启模型梯度检查点功能(牺牲少量速度,大幅节省显存)。
    • 调用 _wrap_model() 包装模型,适配 DDP/FSDP/DeepSpeed 等分布式模式,通过 accelerator.prepare() 完成模型、优化器、调度器的设备迁移与并行适配。
  6. 断点续训:训练状态完整恢复
    • 若指定 resume_from_checkpoint,根据分布式模式加载模型权重、优化器状态、调度器状态、梯度缩放器状态。
    • 加载 TrainerState,计算已训练 epoch 数 epochs_trained 和当前 epoch 已训练步数,生成数据跳过偏移量,避免重复训练已处理数据。
  7. 训练配置日志打印
    输出核心训练信息,方便用户校验配置:
    • 数据维度:总样本数、每 epoch 样本数、总令牌数(若统计)。
    • 训练维度:单设备批次大小、总批次大小(含并行/梯度累积)、梯度累积步数、总优化步数。
    • 模型维度:可训练参数数量、分布式模式、混合精度类型(fp16/bf16)。

子阶段2.2:三层嵌套训练循环(核心执行逻辑)

采用 「Epoch 外层循环 → 梯度累积块中层循环 → 单 Batch 内层循环」 的三层结构,核心是实现梯度累积,平衡显存占用与训练效果。

# 外层:Epoch 循环(遍历训练轮数)
for epoch in range(epochs_trained, num_train_epochs):
    # Epoch 前置准备
    self._epoch_setup(epoch)  # 设置 DataLoader epoch、重置 K/V 缓存等
    epoch_iterator = self.get_train_dataloader()
    total_updates = self._get_num_updates_per_epoch(epoch)  # 本轮 epoch 需完成的更新步数
    tr_loss_epoch = 0.0  # 本轮 epoch 损失累积
    
    # 中层:梯度累积块循环(按梯度累积步数分块)
    for update_step in range(total_updates):
        # 获取当前累积块的所有 Batch 数据(数量 = gradient_accumulation_steps)
        batch_samples, num_items_in_batch = self.get_batch_samples(
            epoch_iterator, args.gradient_accumulation_steps, args.device
        )
        current_grad_accum_steps = len(batch_samples)  # 适配最后一块的剩余 Batch
        tr_loss_block = 0.0  # 本累积块损失累积
        
        # 内层:单 Batch 循环(遍历累积块中的每个 Batch)
        for i, inputs in enumerate(batch_samples):
            # -------------------------- 单 Batch 训练核心 --------------------------
            # 1. 令牌数统计(若启用)
            if args.include_tokens_per_second:
                num_tokens = self._get_num_tokens(inputs)
                self.state.total_flos += self._compute_flos(inputs, num_tokens)
                self.state.num_tokens_seen += num_tokens
            
            # 2. 梯度同步上下文管理(非最后一个 Batch 禁用梯度同步,提升效率)
            sync_grad = (i == current_grad_accum_steps - 1) or (self.state.global_step + 1 >= args.max_steps)
            with self.accelerator.no_sync(model) if not sync_grad else nullcontext():
                # 3. 调用 training_step 完成单 Batch 训练:前向传播→损失计算→反向传播
                tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                tr_loss_block += tr_loss_step
                tr_loss_epoch += tr_loss_step
            
            # -------------------------- 梯度累积完成:参数更新 --------------------------
            if sync_grad:
                # 梯度裁剪:防止梯度爆炸(适配 Apex/DeepSpeed 等特殊模式)
                if args.max_grad_norm > 0:
                    self.accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                
                # 触发优化器步骤前回调
                self.callback_handler.on_pre_optimizer_step(self.args, self.state, self.control)
                
                # 优化器更新参数 + 学习率调度
                self.optimizer.step()
                if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    self.lr_scheduler.step()
                
                # 触发优化器步骤后回调
                self.callback_handler.on_optimizer_step(self.args, self.state, self.control)
                
                # 梯度清零,准备下一个累积块
                model.zero_grad()
                self.state.global_step += 1  # 全局步数+1
                
                # -------------------------- 日志/评估/保存 --------------------------
                self._maybe_log_save_evaluate(
                    model, tr_loss_block / current_grad_accum_steps, epoch, update_step
                )
                
                # -------------------------- 训练终止判断 --------------------------
                if self.state.global_step >= args.max_steps or self.control.should_training_stop:
                    break
        if self.state.global_step >= args.max_steps or self.control.should_training_stop:
            break
    # Epoch 收尾:触发 epoch 结束回调、评估、保存
    self._epoch_end_callback(epoch, tr_loss_epoch)
    if self.state.global_step >= args.max_steps or self.control.should_training_stop:
        break

三层循环的核心细节

  1. Epoch 外层循环

    • 每轮 epoch 开始前调用 _epoch_setup():设置 DataLoader 的 epoch 种子(保证 shuffle 一致性)、重置 Transformer K/V 缓存等。
    • 触发 on_epoch_begin 回调,支持用户自定义 epoch 前置操作(如学习率调整、数据增强策略切换)。
    • 每轮 epoch 结束后调用 _epoch_end_callback():触发 on_epoch_end 回调、执行评估、保存 checkpoint。
  2. 梯度累积块中层循环

    • gradient_accumulation_steps 将 epoch 数据划分为多个累积块,每个块包含 N 个 Batch。
    • 核心目的:累积 N 个 Batch 的梯度后再执行一次参数更新,等价于「批量大小 × N」的训练效果,但显存占用仅为「单 Batch + 梯度缓存」。
    • 处理最后一个累积块的边界情况:若剩余 Batch 数不足 N,自动调整 current_grad_accum_steps,避免丢弃数据。
  3. 单 Batch 内层循环

    • 遍历累积块中的每个 Batch,调用 training_step() 完成单批次训练(前向传播→损失计算→反向传播)。
    • 梯度同步控制:非累积块最后一个 Batch 时,通过 accelerator.no_sync() 禁用梯度同步,减少通信开销。
    • 损失累积:将每个 Batch 的损失累加到块损失 tr_loss_block 和 epoch 损失 tr_loss_epoch

关键辅助操作:_maybe_log_save_evaluate()

每次参数更新后触发,是训练过程中「可视化、验证、固化成果」的核心入口:

  1. 日志记录:打印训练损失、学习率、梯度范数、显存占用等指标,支持 TensorBoard/WandB 等可视化工具。
  2. 模型评估:若达到 eval_steps 或 epoch 结束,调用 _evaluate() 运行验证集,通过 compute_metrics() 计算任务指标(如准确率、BLEU 分数),更新最佳模型指标。
  3. Checkpoint 保存:若达到 save_steps 或触发最佳指标,调用 _save_checkpoint() 保存模型权重、优化器状态、训练状态,支持 save_total_limit 清理过期 checkpoint。
  4. 超参搜索汇报:若启用超参搜索(Optuna/Ray),汇报当前试验结果,支持早停(Prune)效果差的试验。

子阶段2.3:训练循环终止与临时收尾

当满足以下任一条件时,立即终止训练循环:

  1. 全局步数 state.global_step >= max_steps(优先终止条件)。
  2. 回调控制 control.should_training_stop = True(用户自定义早停)。
  3. 所有 epoch 遍历完成。

终止后执行临时收尾:

  • 清理 Transformer K/V 缓存等临时状态,释放内存。
  • 统计训练过程中的总损失、总浮点运算量 total_flos 等指标。

四、阶段3:收尾与保存(训练后)

该阶段的核心目标是完成训练后的资源整理、结果固化,确保训练成果可复用、可追溯,核心操作包括:

  1. 最佳模型加载(可选)

    • load_best_model_at_end=True,调用 _load_best_model() 加载训练过程中基于 metric_for_best_model 判定的最佳模型 Checkpoint。
  2. 训练结果整理与指标汇总

    • 计算平均训练损失:平均损失 = _total_loss_scalar / effective_global_step(做保底处理,避免除零错误)。
    • 收集多维度训练指标:速度指标(每秒样本数、每秒令牌数)、浮点运算量(total_flos)、内存使用情况、可训练参数数量等。
  3. 资源清理与 Checkpoint 瘦身

    • 停用 NEFTune(若启用):调用 _deactivate_neftune() 移除嵌入层前向钩子,释放相关资源。
    • 根据 save_total_limit 删除过期 Checkpoint,仅保留最新/最佳的若干个,节省存储空间。
    • 清理过往状态缓存(_past,如 Transformer K/V 缓存)、释放加速器闲置显存。
  4. 模型推送与结果返回

    • push_to_hub=True,调用 push_to_hub() 上传模型、配置文件、训练参数到 Hugging Face Hub,实现模型共享。
    • 封装 TrainOutput 对象,返回全局步数、平均训练损失、训练指标字典,结束整个训练流程。

五、三大核心方法:Trainer 训练的「核心引擎」

_inner_training_loop()training_step()compute_loss()Trainer 类的三大核心方法,构成了训练流程的「三级调用链路」,也是整个训练逻辑的核心承载者,三者分工明确、层层调用。

方法1:_inner_training_loop() - 核心训练主循环(统筹全局)

该方法是「端到端训练」的核心统筹者,承接训练前的准备工作,向下调用 training_step() 完成单批次训练,最终完成所有 epoch/step 的训练、结果汇总和清理,核心职责是搭建训练框架、处理分布式适配、实现梯度累积与断点续训。

核心职责 关键操作
训练前配置 显存预热、参数校准、模型包装、断点续训状态恢复
训练循环控制 三层嵌套循环(Epoch→梯度累积块→单Batch)、终止条件判断
全局状态管理 训练状态记录、日志/评估/保存触发、回调事件调度
分布式适配 DDP/FSDP/DeepSpeed 模式支持、多设备梯度同步与损失聚合

方法2:training_step() - 单批次训练(执行具体训练步骤)

该方法是「单批次训练」的执行者,负责完成单批次数据的「前向传播 → 损失计算 → 反向传播准备」的闭环,是梯度更新的基础,被 _inner_training_loop() 循环调用。

核心执行流程

  1. 前置准备工作

    • 上下文并行准备:调用 _prepare_context_parallel_inputs 生成分布式上下文并行环境(cp_context)。
    • 训练模式激活:调用 model.train() 切换模型为训练模式(启用 Dropout、BatchNorm 等训练专属层行为),同步激活优化器训练模式(若支持)。
    • 输入数据预处理:调用 _prepare_inputs 完成设备迁移、数据格式转换,确保数据符合模型输入要求。
  2. 核心步骤:损失计算与反向传播准备

    • SageMaker 并行适配:若启用 SageMaker 模型并行,通过 smp_forward_backward 完成前向与反向传播,返回聚合损失。
    • 常规损失计算:通过 self.compute_loss() 调用 compute_loss() 方法,完成模型前向传播与批次损失计算(借助 compute_loss_context_manager() 提供梯度累积上下文支持)。
    • 显存优化:删除临时输入数据 inputs,若配置 torch_empty_cache_steps,按间隔清理对应硬件设备的闲置显存(CUDA/XPU/TPU 等),避免 OOM。
  3. 多设备与混合精度适配

    • 多 GPU 损失聚合:若 n_gpu > 1,通过 loss.mean() 对各 GPU 损失求平均,保证损失一致性。
    • 混合精度反向传播:
    • 启用 Apex 混合精度:通过 amp.scale_loss 缩放损失,避免梯度下溢后执行反向传播。
    • 启用 Accelerator 混合精度:通过 self.accelerator.backward() 完成反向传播初始化,适配 DeepSpeed 分布式模式。
  4. 损失返回:通过 loss.detach() 使损失张量脱离计算图(不参与后续不必要的梯度传播),返回该批次的损失值。

核心关键点

  • 职责明确:仅负责单批次训练的完整步骤,不参与全局循环与梯度累积判断,保证逻辑简洁。
  • 显存优化优先:包含多处显存清理逻辑,平衡训练效率与显存占用。
  • 计算图管理:返回脱离计算图的损失值,节省计算资源,避免梯度泄露。

方法3:compute_loss() - 损失计算(专用工具方法)

该方法是「损失计算」的专用工具,负责处理模型输入、调用前向传播、支持自定义损失与标签平滑,最终返回批次损失,被 training_step() 调用。

核心执行流程

  1. 标签数据预处理

    • 若配置标签平滑(label_smoother)或自定义损失函数(compute_loss_func),且输入包含 labels,则弹出 labels 单独保存(避免重复传入模型)。
    • 否则,将 labels 设为 None,不单独处理。
  2. 模型前向传播

    • 若模型支持损失相关关键字参数,将 num_items_in_batch(批次样本数)注入输入字典。
    • 调用 model(**inputs) 执行前向传播,得到模型输出 outputs,可选保存过往状态(如 Transformer K/V 缓存)用于后续优化。
  3. 多场景损失计算(优先级:自定义损失 > 标签平滑损失 > 模型自带损失)

    • 场景 1(自定义损失):若 compute_loss_func 已定义,调用该函数传入模型输出、标签、批次样本数,返回自定义损失。
    • 场景 2(标签平滑损失):若无自定义损失但 labels 非空,通过 label_smoother 计算损失(因果语言模型会执行标签移位 shift_labels=True,避免看到未来标签)。
    • 场景 3(模型自带损失):若既无自定义损失也无单独标签,从模型输出中提取损失(字典输出取 outputs["loss"],元组输出取 outputs[0]),无法提取则抛出异常。
  4. 多设备适配与结果返回

    • 若配置跨设备令牌平均,将损失值乘以进程数/GPU 数,保证多设备损失计算准确性。
    • 根据 return_outputs 参数,返回「仅损失」或「(损失,模型输出)」。

核心关键点

  • 专一性:仅负责损失计算,不参与反向传播与参数更新,便于用户自定义修改。
  • 灵活性:支持三种损失计算场景,满足不同任务(分类、生成、微调等)的需求。
  • 鲁棒性:充分考虑多设备、不同模型输出格式的适配,减少训练报错概率。

三大方法的调用关系与分工总结

方法名 核心定位 被调用者 调用者 核心输出
_inner_training_loop() 训练主循环(统筹全局) train() training_step() 训练状态、总损失、各类指标
training_step() 单批次训练(执行具体步骤) _inner_training_loop() compute_loss() 单批次损失值
compute_loss() 损失计算(专用工具) training_step() 模型 forward() 方法 批次损失(可选模型输出)

六、关键核心概念:Trainer 训练的「底层支撑」

  1. 梯度累积

    • 核心逻辑:在 _inner_training_loop() 的三层嵌套循环中,通过收集 gradient_accumulation_steps 个 Batch 的梯度,仅在最后一个 Batch 执行 optimizer.step(),实现「小批次显存占用,大批次训练效果」。
    • 核心价值:解决显存不足无法设置大批次的问题,平衡显存占用与模型训练稳定性(大批次通常能带来更好的收敛效果)。
  2. 分布式训练

    • 核心支撑:基于 Accelerator 封装,无需用户手动编写分布式逻辑。
    • 支持模式:分布式数据并行(DDP)、完全分片数据并行(FSDP)、DeepSpeed、SageMaker 模型并行、TPU/XLA 并行等。
    • 核心适配:自动处理模型包装、梯度同步、数据分发、损失聚合,保证多设备/多进程训练的一致性。
  3. 回调系统

    • 核心载体:CallbackHandler 管理各类回调函数。
    • 关键事件:on_train_begin(训练开始)、on_epoch_begin/end(epoch 开始/结束)、on_step_end(步骤结束)、on_train_end(训练结束)。
    • 核心价值:支持用户自定义扩展训练逻辑(如早停、自定义日志推送、额外评估),无需修改 Trainer 核心源码。
  4. Checkpoint 机制

    • 保存内容:模型权重、优化器状态、学习率调度器状态、TrainerState(训练状态)、RNG 随机状态。
    • 核心方法:_save_checkpoint()(保存)、_load_from_checkpoint()(加载)。
    • 核心价值:实现断点续训(避免训练中断后重新开始)、保存最佳模型、后续模型微调与部署的基础。
  5. 混合精度训练

    • 核心模式:fp16(半精度)、bf16(脑浮点数)。
    • 核心优化:通过损失缩放(避免梯度下溢)、自动精度转换,在保证模型收敛效果的前提下,大幅节省显存占用、提升训练速度。
    • 适配方式:基于 Accelerator 或 Apex 封装,用户仅需配置 fp16=Truebf16=True 即可启用。

七、整体总结

  1. 流程规范Trainer 训练流程遵循「初始化 → 核心循环 → 收尾」的三段式架构,逻辑清晰、层层递进,保证了训练的规范性和可复现性。
  2. 分工明确:三大核心方法构成「三级调用链路」,各司其职,既保证了核心逻辑的封装性,又为用户自定义提供了灵活入口。
  3. 功能完备:内置梯度累积、分布式训练、断点续训、混合精度等核心功能,无需用户关注底层实现细节,降低上手门槛。
  4. 扩展性强:通过自定义 compute_loss()、重写 training_step()、自定义回调函数,可满足各类个性化训练需求(如特殊任务损失、自定义优化逻辑)。

Trainer 类的核心价值在于封装了复杂的底层训练逻辑,将用户从繁琐的分布式适配、梯度管理、显存优化中解放出来,专注于模型与任务本身,是 Hugging Face Transformers 库中自然语言处理、计算机视觉等任务训练的核心基石。

Hugging Face Trainer 核心流程关键节点对照表

这份对照表整合了训练全流程的阶段划分、核心操作、输入输出、关键方法、注意事项,方便你快速查阅和定位核心逻辑。

层级 名称 核心操作 输入 输出 关键调用方法 注意事项
全局入口 train() 方法 串联三大阶段,完成训练启动与结果返回 TrainingArguments、模型、数据集 TrainOutput(步数、损失、指标) _inner_training_loop() 是所有训练逻辑的统一入口,前置完成初始化校验
阶段1 初始化与准备工作 1. 参数/环境校验
2. 模型/优化器/调度器初始化
3. 数据加载预处理
4. 断点续训加载
配置参数、空白模型、原始数据集 校验后的训练环境、初始化的模型/优化器、预处理数据加载器、恢复的训练状态 create_optimizer_and_scheduler()
get_train_dataloader()
_load_from_checkpoint()
1. 自动过滤模型不接受的输入列
2. 总批次大小 = 单设备批次 × 梯度累积 × 进程数
3. 断点续训需校验参数一致性
阶段2.1 训练循环前置准备 1. 显存预热与批量适配
2. 核心训练参数计算
3. 分布式模型包装
4. 训练状态校准
阶段1输出的环境/数据/模型 适配分布式的模型/优化器、校准后的训练步数(max_steps)、跳过已训练数据的偏移量 set_initial_training_values()
_wrap_model()
accelerator.prepare()
1. max_steps 优先级高于 num_train_epochs
2. 梯度检查点会牺牲速度换显存
3. 不同分布式模式(DDP/FSDP)的模型包装逻辑不同
阶段2.2 Epoch 外层循环 1. 每轮 epoch 初始化(种子/缓存)
2. 触发 epoch 前后回调
3. 控制 epoch 迭代终止
校准后的训练状态、数据加载器 每轮 epoch 的损失累积值、更新的训练状态 _epoch_setup()
_epoch_end_callback()
1. 设置 DataLoader epoch 保证 shuffle 一致性
2. 支持用户自定义 epoch 级操作(回调)
阶段2.2 梯度累积块中层循环 1. 按累积步数划分 Batch 块
2. 处理最后一块的边界 Batch
单 epoch 数据迭代器 累积块的 Batch 列表、块内样本数 get_batch_samples() 核心目的:小显存实现大批次效果,最后一块自动适配剩余 Batch 数
阶段2.2 单 Batch 内层循环 1. 单批次训练(前向/损失/反向)
2. 梯度同步控制
3. 损失累积
单个 Batch 数据、包装后的模型 单 Batch 损失值、累积的梯度 training_step()
accelerator.no_sync()
非最后一个 Batch 禁用梯度同步,减少通信开销
阶段2.2 参数更新步骤 1. 梯度裁剪
2. 优化器更新
3. 学习率调度
4. 梯度清零
累积的梯度、优化器/调度器 更新后的模型参数、更新的学习率、+1 的全局步数 accelerator.clip_grad_norm_()
optimizer.step()
lr_scheduler.step()
1. 梯度裁剪防止梯度爆炸
2. 仅在累积块完成后执行一次更新
阶段2.2 日志/评估/保存 1. 训练指标记录
2. 验证集评估
3. Checkpoint 保存
4. 超参搜索汇报
训练状态、模型、验证集 日志指标、最佳模型指标、保存的 Checkpoint _maybe_log_save_evaluate()
_evaluate()
_save_checkpoint()
1. 按 eval_steps/save_steps 触发
2. save_total_limit 自动清理过期 Checkpoint
阶段2.3 训练循环终止 1. 判断终止条件
2. 临时资源清理
全局步数、回调控制状态 终止信号、清理后的临时内存 - 终止条件:达到 max_steps / 回调早停 / epoch 遍历完成
阶段3 收尾与保存 1. 加载最佳模型(可选)
2. 训练指标汇总
3. 资源清理与 Checkpoint 瘦身
4. 模型推送到 Hub
训练状态、保存的 Checkpoint 最佳模型权重、训练指标字典、清理后的存储空间 _load_best_model()
speed_metrics()
push_to_hub()
1. 最佳模型基于 metric_for_best_model 判定
2. 停用 NEFTune 钩子,释放显存
核心方法 training_step() 单 Batch 完整训练步骤:前向→损失→反向→显存优化 模型、单 Batch 数据、样本数 脱离计算图的单 Batch 损失值 compute_loss()
accelerator.backward()
1. 自动适配 SageMaker 并行
2. 定期清理显存避免 OOM
核心方法 compute_loss() 多场景损失计算:自定义损失→标签平滑→模型自带损失 模型、Batch 数据、标签(可选) 批次损失值(可选模型输出) model(**inputs)
label_smoother()
1. 因果语言模型自动执行标签移位
2. 支持用户自定义损失函数

posted @ 2026-01-13 15:44  玉米面手雷王  阅读(2)  评论(0)    收藏  举报