全部文章

模型保存与加载

一个模型,我们训练了很长时间,我们不可能使用的时候再次训练,那么如何保存呢?

训练好的模型保存非常重要,这样可以在需要时直接加载使用,无需重新训练。以下是保存和加载模型的常用方法:

1. Python内置模块:pickle

import pickle

# 保存模型
with open('model.pkl', 'wb') as f:
    pickle.dump(trained_model, f)

# 加载模型
with open('model.pkl', 'rb') as f:
    loaded_model = pickle.load(f)

2. joblib (适合大型numpy数组的模型)

from joblib import dump, load

# 保存模型
dump(trained_model, 'model.joblib')

# 加载模型
loaded_model = load('model.joblib')

3. 框架特定保存方法

Scikit-learn 模型

from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import joblib

# 训练模型
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y)
model = RandomForestClassifier().fit(X_train, y_train)

# 保存
joblib.dump(model, 'rf_model.pkl')

# 加载
loaded_model = joblib.load('rf_model.pkl')

TensorFlow/Keras 模型

# 保存整个模型
model.save('my_model.keras')  # 或者 my_model.h5

# 加载
from tensorflow.keras.models import load_model
loaded_model = load_model('my_model.keras')

# 保存权重
model.save_weights('model_weights.weights.h5')

# 加载权重
model.load_weights('model_weights.weights.h5')

PyTorch 模型

# 保存整个模型
torch.save(model, 'model.pth')

# 加载
loaded_model = torch.load('model.pth')

# 保存模型状态字典(推荐)
torch.save(model.state_dict(), 'model_state.pth')

# 加载状态字典
model = ModelClass()  # 必须先重新创建模型结构
model.load_state_dict(torch.load('model_state.pth'))
model.eval()  # 切换为评估模式

ONNX格式(跨框架)

 

# 将Scikit-learn模型转为ONNX
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

initial_type = [('float_input', FloatTensorType([None, 4]))]
onnx_model = convert_sklearn(model, initial_types=initial_type)

# 保存
with open("model.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())

# 加载(使用ONNX Runtime)
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx")
input_name = sess.get_inputs()[0].name
predictions = sess.run(None, {input_name: X_test.astype(np.float32)})[0]

5. 云端存储(如AWS S3, Google Cloud Storage)

import boto3

# 保存到S3
s3 = boto3.client('s3')
s3.upload_file('model.pkl', 'my-bucket', 'models/model.pkl')

# 从S3加载
s3.download_file('my-bucket', 'models/model.pkl', 'local_model.pkl')
with open('local_model.pkl', 'rb') as f:
    model = pickle.load(f)

最佳实践

  1. ​保存时机​​:在验证集性能最好时保存
  2. ​元数据保存​​:额外保存模型版本、训练参数、特征工程流程等信息
  3. ​完整保存​​:

 

import joblib
import datetime

model_info = {
    'model': trained_model,
    'version': '1.0',
    'training_date': datetime.datetime.now(),
    'features': ['age', 'income', ...],
    'metrics': {'accuracy': 0.95, 'f1': 0.92}
}

joblib.dump(model_info, 'model_with_metadata.pkl')

4.自动化部署​​:使用CI/CD管道自动部署模型更新

​5.版本控制​​:结合Git管理模型文件(注意大文件用Git LFS)

​6.​格式选择​​:

      • 生产环境:ONNX 或框架原生格式
      • 开发环境:pickle/joblib
      • 跨平台:ONNX

示例:完整的模型部署流程

# 训练和保存
import joblib
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris

iris = load_iris()
model = RandomForestClassifier()
model.fit(iris.data, iris.target)

joblib.dump(model, 'iris_classifier.joblib')

# -------------------------
# 在另一个脚本/应用中使用
from flask import Flask, request, jsonify
import joblib
import numpy as np

app = Flask(__name__)
model = joblib.load('iris_classifier.joblib')

@app.route('/predict', methods=['POST'])
def predict():
    data = request.json
    features = np.array(data['features']).reshape(1, -1)
    prediction = model.predict(features)[0]
    return jsonify({'class': int(prediction)})

if __name__ == '__main__':
    app.run(port=5000)

注意事项

  1. ​依赖管理​​:保持训练和预测环境相同(Python版本、库版本)
  2. ​安全考虑​​:不要加载不可信的序列化文件(pickle有安全风险)
  3. ​大文件处理​​:大型模型考虑使用分布式文件系统
  4. ​持续监控​​:部署后监控模型性能衰减
  5. ​反序列化兼容性​​:避免库版本升级导致加载失败

通过合理选择保存方法和遵循最佳实践,可以确保训练好的模型能够长期可靠地服务于生产环境。

posted @ 2025-05-31 20:14  指尖下的世界  阅读(49)  评论(0)    收藏  举报