# Digit recognition model
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
import joblib
import time
class DigitRecognizer:
    """Train, evaluate, query and persist a scikit-learn digit classifier.

    Typical use is ``train()`` followed by ``predict_single()`` /
    ``show_sample_predictions()`` and ``save_model()``; a previously saved
    estimator can be restored with ``load_model()``.
    """

    def __init__(self):
        # Fitted estimator; stays None until train() or load_model() succeeds.
        self.model = None
        # Seconds spent inside the last fit() call; None before training.
        self.training_time = None
        # Held-out (X_test, y_test) from the last train() call, kept so that
        # sample predictions are shown on data the model was not fitted on.
        self._test_data = None

    def load_data(self, n_samples=10000):
        """Load MNIST, scale pixel values to [0, 1] and split it 80/20.

        Args:
            n_samples: Maximum number of leading samples to keep; the full
                dataset is used when it has no more samples than this.

        Returns:
            ``((X_train, y_train), (X_test, y_test))``. If anything goes
            wrong while fetching or preparing MNIST, the result of
            ``load_backup_data()`` is returned instead.
        """
        print("正在下载MNIST数据集...")
        try:
            mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
            X, y = mnist.data, mnist.target.astype(int)
            # Use only part of the data to speed up training.
            if n_samples < len(X):
                X = X[:n_samples]
                y = y[:n_samples]
            # Scale raw pixel values into the 0-1 range.
            X = X / 255.0
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=42, stratify=y
            )
            print(f"数据加载完成: 训练集 {X_train.shape[0]} 样本, 测试集 {X_test.shape[0]} 样本")
            return (X_train, y_train), (X_test, y_test)
        except Exception as e:
            # Deliberate best-effort: any failure (network, cache, parsing)
            # falls back to synthetic data so the demo can still run.
            print(f"数据加载错误: {e}")
            return self.load_backup_data()

    def load_backup_data(self):
        """Generate synthetic 10-class data as a stand-in when MNIST is unavailable.

        Returns:
            ``((X_train, y_train), (X_test, y_test))`` built from 2000
            samples with 64 features each, split 80/20. Note that the feature
            count differs from the 784 features of the MNIST path.
        """
        print("使用备用数据生成方法...")
        from sklearn.datasets import make_classification
        X, y = make_classification(
            n_samples=2000, n_features=64, n_informative=20,
            n_redundant=10, n_classes=10, random_state=42
        )
        # Fold values into [0, 1] so they resemble normalised pixel values.
        X = np.abs(X) / np.max(np.abs(X))
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )
        return (X_train, y_train), (X_test, y_test)

    def build_model(self, model_type='random_forest'):
        """Create an unfitted estimator and store it in ``self.model``.

        Args:
            model_type: ``'random_forest'`` or ``'logistic'``.

        Returns:
            The newly created estimator.

        Raises:
            ValueError: If ``model_type`` is not one of the supported names.
                (Previously an unknown name left ``self.model`` untouched and
                training failed later with an unrelated error.)
        """
        if model_type not in ('random_forest', 'logistic'):
            raise ValueError(
                f"Unknown model_type {model_type!r}; "
                "expected 'random_forest' or 'logistic'"
            )
        if model_type == 'random_forest':
            self.model = RandomForestClassifier(
                n_estimators=100,
                max_depth=20,
                random_state=42,
                n_jobs=-1  # use every CPU core
            )
        else:
            from sklearn.linear_model import LogisticRegression
            # The deprecated multi_class='multinomial' argument is omitted:
            # with the lbfgs solver a multinomial fit is already the default
            # for multi-class targets.
            self.model = LogisticRegression(
                solver='lbfgs',
                max_iter=1000,
                random_state=42
            )
        print(f"使用 {model_type} 模型")
        return self.model

    def train(self, model_type='random_forest', n_samples=10000):
        """Load data, fit a fresh model and report its held-out performance.

        Args:
            model_type: Forwarded to ``build_model()``.
            n_samples: Forwarded to ``load_data()``.

        Returns:
            Accuracy on the held-out test split.

        Raises:
            ValueError: If ``model_type`` is not supported.
        """
        print("开始训练数字识别模型...")
        (X_train, y_train), (X_test, y_test) = self.load_data(n_samples)
        # Remember the held-out split so later sample predictions use data
        # this model has not been fitted on.
        self._test_data = (X_test, y_test)
        self.build_model(model_type)

        start_time = time.time()
        print("训练中...")
        self.model.fit(X_train, y_train)
        self.training_time = time.time() - start_time

        print("评估模型性能...")
        y_pred = self.model.predict(X_test)
        accuracy = accuracy_score(y_test, y_pred)
        print("\n=== 模型训练完成 ===")
        print(f"训练时间: {self.training_time:.2f} 秒")
        print(f"测试准确率: {accuracy:.4f}")
        print(f"错误率: {(1-accuracy):.4f}")
        print("\n分类报告:")
        print(classification_report(y_test, y_pred))
        return accuracy

    def predict_single(self, image_array, normalize=True):
        """Classify one image.

        Args:
            image_array: Pixel values for a single image; multi-dimensional
                input is flattened to one feature vector.
            normalize: When True (the default) the values are divided by
                255, matching the scaling applied in ``load_data()``. Pass
                False for input that is already scaled to [0, 1].

        Returns:
            ``(label, confidence, probabilities)`` where ``confidence`` is
            the highest class probability, or ``None`` (after printing a
            notice) when no model is available.
        """
        if self.model is None:
            print("请先训练模型!")
            return None
        pixels = np.asarray(image_array, dtype=float).ravel()
        if normalize:
            pixels = pixels / 255.0
        sample = pixels.reshape(1, -1)  # estimators expect a 2-D batch
        prediction = self.model.predict(sample)
        probabilities = self.model.predict_proba(sample)
        confidence = np.max(probabilities)
        return prediction[0], confidence, probabilities[0]

    def show_sample_predictions(self, n_samples=5):
        """Print predictions for randomly chosen held-out samples.

        Uses the test split kept by the last ``train()`` call; if there is
        none (for example after ``load_model()``), a fresh split is loaded
        via ``load_data(1000)``.

        Args:
            n_samples: How many samples to show; capped at the size of the
                available test split.
        """
        if self.model is None:
            print("请先训练模型!")
            return
        test_data = self._test_data
        if test_data is None:
            _, test_data = self.load_data(1000)
        X_test, y_test = test_data
        # Sampling without replacement cannot draw more items than exist.
        n_samples = min(n_samples, len(X_test))
        print(f"\n随机显示 {n_samples} 个测试样本的预测结果:")
        indices = np.random.choice(len(X_test), n_samples, replace=False)
        for i, idx in enumerate(indices):
            true_label = y_test[idx]
            # The split is already scaled to [0, 1]; normalising again would
            # shrink every feature by another factor of 255.
            pred, confidence, probs = self.predict_single(X_test[idx], normalize=False)
            print(f"样本 {i+1}: 真实值={true_label}, 预测值={pred}, 置信度={confidence:.4f}")
            if true_label != pred:
                print(f" ✗ 预测错误! 各类别概率: {[f'{p:.3f}' for p in probs]}")

    def save_model(self, filename='digit_recognizer_rf.pkl'):
        """Write the current model to ``filename`` with joblib, if one exists."""
        if self.model is not None:
            joblib.dump(self.model, filename)
            print(f"\n模型已保存为: {filename}")
        else:
            print("没有训练好的模型可保存")

    def load_model(self, filename='digit_recognizer_rf.pkl'):
        """Restore a model previously written by ``save_model()``.

        A missing file is reported with a message and leaves ``self.model``
        unchanged.
        """
        try:
            # SECURITY: joblib files are pickle-based and can execute
            # arbitrary code when loaded; only load files you trust.
            self.model = joblib.load(filename)
            print(f"模型已从 {filename} 加载")
        except FileNotFoundError:
            print(f"模型文件 {filename} 不存在")
# Usage example
def main():
    """Run the full demo: train a model, preview a few predictions, save it."""
    divider = "=" * 50
    print(divider)
    print(" 手写数字识别模型 (使用Scikit-learn)")
    print(divider)

    recognizer = DigitRecognizer()

    # A reduced sample count keeps training quick; raise it as needed.
    # model_type may be 'random_forest' or 'logistic'.
    accuracy = recognizer.train(model_type='random_forest', n_samples=5000)

    # Preview predictions on a handful of samples, then persist the model.
    recognizer.show_sample_predictions(3)
    recognizer.save_model()

    print(f"\n🎉 模型训练完成!最终准确率: {accuracy:.4f}")
# Quick test version
def quick_test():
    """Fit a small random forest on scikit-learn's bundled digits data.

    Prints the accuracy on a 20% held-out split and saves the fitted model
    to ``quick_digit_model.pkl`` in the current directory.
    """
    print("快速测试数字识别模型...")
    from sklearn.datasets import load_digits
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import accuracy_score

    # The bundled digits dataset is far smaller than MNIST.
    dataset = load_digits()
    features_train, features_test, labels_train, labels_test = train_test_split(
        dataset.data, dataset.target, test_size=0.2, random_state=42
    )

    forest = RandomForestClassifier(n_estimators=50, random_state=42)
    forest.fit(features_train, labels_train)

    score = accuracy_score(labels_test, forest.predict(features_test))
    print(f"快速测试准确率: {score:.4f}")

    output_path = 'quick_digit_model.pkl'
    joblib.dump(forest, output_path)
    print("快速模型已保存为 'quick_digit_model.pkl'")
# Script entry point: only runs when the file is executed directly, not when
# it is imported as a module.
if __name__ == "__main__":
    # Run the full pipeline (MNIST download, training, sample predictions, save).
    main()
    # Alternatively, run the lightweight quick test instead:
    # quick_test()

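# Footer text carried over from the web page this script was copied from (not part of the program): 浙公网安备 33010602011771号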