• 博客园logo
  • 会员
  • 众包
  • 新闻
  • 博问
  • 闪存
  • 赞助商
  • HarmonyOS
  • Chat2DB
    • 搜索
      所有博客
    • 搜索
      当前博客
  • 写随笔 我的博客 短消息 简洁模式
    用户头像
    我的博客 我的园子 账号设置 会员中心 简洁模式 ... 退出登录
    注册 登录

LR233

  • 博客园
  • 联系
  • 订阅
  • 管理

公告

View Post

3、ModelCheckPoint

1、导包

1 from tensorflow.keras.callbacks import ModelCheckpoint

2、介绍

  在训练机器学习模型时,经常需要缓存模型。

  ModelCheckpoint是Pytorch Lightning中的一个Callback,它就是用于模型缓存的。

  它会监视某个指标,每次指标达到最好的时候,它就缓存当前模型。

  在每个epoch结束作为回调函数,保存模型。

3、参数介绍

3.1、monitor='val_loss', 我们想要监视的指标 ,val_acc或val_loss。

3.2、dirpath='my/path/', 模型缓存目录

3.3、verbose: 详细信息模式,0 或者1。 0为不打印输出信息,1为打印

3.4、save_best_only: True,将只保存在验证集上性能最好的模型mode: {auto, min, max} 的其中之一。是否覆盖保存文件的决定就取决于被监测数据的最大或者最小值。 

对于val_acc,模式就会是max;而对于val_loss,模式就需要是min。在auto模式中,方式会自动从被监测的数据的名字中判断出来。

3.5、save_weights_only: 如果 True,那么只有模型的权重会被保存 (model.save_weights(filepath)), 否则的话,整个模型会被保存 (model.save(filepath))。

3.6、period: 每个检查点之间的间隔(训练轮数)。

posted on 2022-10-24 21:27  LR233  阅读(1501)  评论(0)    收藏  举报

刷新页面返回顶部
 
博客园  ©  2004-2025
浙公网安备 33010602011771号 浙ICP备2021040463号-3