tensorflow2.0 - 保存模型(含自定义模型的保存)

tensorflow2.0保存模型的方式有很多,这里只介绍两种。

一、 使用官方模型

这种情况可以直接保存整个模型,如下所示,可以将模型保存为HDF5文件

# 创建模型实例
model = create_model()
# 保存模型到HDF5文件
model.save('my_model.h5')
# 读取模型
model = keras.models.load_model('my_model.h5')

二、自定义模型

如果是自定义模型使用上述方法保存会报错且保存失败,报错为:

NotImplementedError: Saving the model to HDF5 format requires the model to be a Functional model or a Sequential model. It does not work for subclassed models, because such models are defined via the body of a Python method, which isn’t safely serializable. Consider saving to the Tensorflow SavedModel format (by setting save_format=“tf”) or using save_weights.

这种情况可以保存weight。

# 创建模型
model = create_model()
# 保存权重
model.save_weights('model_weight')
# 创建新模型读取权重
newModel = create_model()
# 读取权重到新模型
newModel.load_weights('model_weight')

参考文献:

posted @ 2020-02-04 15:18  _吟游诗人  阅读(3228)  评论(0编辑  收藏  举报