模型保存与加载
一个模型,我们训练了很长时间,我们不可能使用的时候再次训练,那么如何保存呢?
训练好的模型保存非常重要,这样可以在需要时直接加载使用,无需重新训练。以下是保存和加载模型的常用方法:
1. Python内置模块:pickle
import pickle
# 保存模型
with open('model.pkl', 'wb') as f:
pickle.dump(trained_model, f)
# 加载模型
with open('model.pkl', 'rb') as f:
loaded_model = pickle.load(f)
2. joblib (适合大型numpy数组的模型)
from joblib import dump, load
# 保存模型
dump(trained_model, 'model.joblib')
# 加载模型
loaded_model = load('model.joblib')
3. 框架特定保存方法
Scikit-learn 模型
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import joblib
# 训练模型
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y)
model = RandomForestClassifier().fit(X_train, y_train)
# 保存
joblib.dump(model, 'rf_model.pkl')
# 加载
loaded_model = joblib.load('rf_model.pkl')
TensorFlow/Keras 模型
# 保存整个模型
model.save('my_model.keras') # 或者 my_model.h5
# 加载
from tensorflow.keras.models import load_model
loaded_model = load_model('my_model.keras')
# 保存权重
model.save_weights('model_weights.weights.h5')
# 加载权重
model.load_weights('model_weights.weights.h5')
PyTorch 模型
# 保存整个模型
torch.save(model, 'model.pth')
# 加载
loaded_model = torch.load('model.pth')
# 保存模型状态字典(推荐)
torch.save(model.state_dict(), 'model_state.pth')
# 加载状态字典
model = ModelClass() # 必须先重新创建模型结构
model.load_state_dict(torch.load('model_state.pth'))
model.eval() # 切换为评估模式
# 将Scikit-learn模型转为ONNX
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
initial_type = [('float_input', FloatTensorType([None, 4]))]
onnx_model = convert_sklearn(model, initial_types=initial_type)
# 保存
with open("model.onnx", "wb") as f:
f.write(onnx_model.SerializeToString())
# 加载(使用ONNX Runtime)
import onnxruntime as ort
sess = ort.InferenceSession("model.onnx")
input_name = sess.get_inputs()[0].name
predictions = sess.run(None, {input_name: X_test.astype(np.float32)})[0]
5. 云端存储(如AWS S3, Google Cloud Storage)
import boto3
# 保存到S3
s3 = boto3.client('s3')
s3.upload_file('model.pkl', 'my-bucket', 'models/model.pkl')
# 从S3加载
s3.download_file('my-bucket', 'models/model.pkl', 'local_model.pkl')
with open('local_model.pkl', 'rb') as f:
model = pickle.load(f)
import joblib
import datetime
model_info = {
'model': trained_model,
'version': '1.0',
'training_date': datetime.datetime.now(),
'features': ['age', 'income', ...],
'metrics': {'accuracy': 0.95, 'f1': 0.92}
}
joblib.dump(model_info, 'model_with_metadata.pkl')
4.自动化部署:使用CI/CD管道自动部署模型更新
5.版本控制:结合Git管理模型文件(注意大文件用Git LFS)
6.格式选择:
-
-
- 生产环境:ONNX 或框架原生格式
- 开发环境:pickle/joblib
- 跨平台:ONNX
-
示例:完整的模型部署流程
# 训练和保存
import joblib
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
iris = load_iris()
model = RandomForestClassifier()
model.fit(iris.data, iris.target)
joblib.dump(model, 'iris_classifier.joblib')
# -------------------------
# 在另一个脚本/应用中使用
from flask import Flask, request, jsonify
import joblib
import numpy as np
app = Flask(__name__)
model = joblib.load('iris_classifier.joblib')
@app.route('/predict', methods=['POST'])
def predict():
data = request.json
features = np.array(data['features']).reshape(1, -1)
prediction = model.predict(features)[0]
return jsonify({'class': int(prediction)})
if __name__ == '__main__':
app.run(port=5000)
注意事项
- 依赖管理:保持训练和预测环境相同(Python版本、库版本)
- 安全考虑:不要加载不可信的序列化文件(pickle有安全风险)
- 大文件处理:大型模型考虑使用分布式文件系统
- 持续监控:部署后监控模型性能衰减
- 反序列化兼容性:避免库版本升级导致加载失败
通过合理选择保存方法和遵循最佳实践,可以确保训练好的模型能够长期可靠地服务于生产环境。