# Handwritten Chinese characters (手写汉字)
#
# Final version without PCA: fixes the dimension error

import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.pipeline import Pipeline
import time
import joblib

# ---------------------- 1. Core configuration ----------------------

# Seed handed to the train/test split and to the SVM.
RANDOM_SEED = 42
# Dataset root directory; the script keeps one sub-folder per class index in it.
DATA_ROOT = "chinese_digit_dataset"
# Size every image is resized to before it is flattened into a feature vector.
IMAGE_SIZE = (32, 32)
# Display name of each class; list position == integer class label.
CLASS_NAMES = list("零一二三四五六七八九十佰仟万亿")
NUM_CLASSES = len(CLASS_NAMES)

# ---------------------- 2. Create the folder structure automatically ----------------------

# Create the dataset root plus one sub-folder per class label.
# The class sub-folders are ensured even when the root already exists, so a
# partially created or manually emptied dataset directory is repaired instead
# of breaking the image generation and loading steps later in this script.
root_existed = os.path.exists(DATA_ROOT)
os.makedirs(DATA_ROOT, exist_ok=True)
for label in range(NUM_CLASSES):
    # exist_ok=True makes this a no-op for folders that are already present.
    os.makedirs(os.path.join(DATA_ROOT, str(label)), exist_ok=True)

if root_existed:
    print(f"📂 数据集文件夹已存在:{DATA_ROOT}")
else:
    print(f"✅ 已自动创建数据集文件夹:{DATA_ROOT}(包含0-14子文件夹)")

# ---------------------- 3. Generate test images automatically ----------------------

def generate_test_images(data_root, class_count, img_size):
    """Write one synthetic placeholder image per class into the dataset tree.

    For every label ``i`` in ``range(class_count)`` a white grayscale canvas
    is drawn on and saved as ``<data_root>/<i>/<i>.png``. Labels 0-3 get
    hard-coded strokes; every other label is rendered as its decimal index.

    Args:
        data_root: Dataset root directory. Missing class sub-folders are
            created on demand.
        class_count: Number of classes to generate an image for.
        img_size: ``(width, height)`` of the generated images in pixels.

    Raises:
        OSError: If OpenCV reports that an image could not be written.
    """
    width, height = img_size
    font = cv2.FONT_HERSHEY_SIMPLEX
    thickness = 2

    for label_idx in range(class_count):
        class_folder = os.path.join(data_root, str(label_idx))
        # cv2.imwrite does not create missing directories, it just fails.
        os.makedirs(class_folder, exist_ok=True)
        img_path = os.path.join(class_folder, f"{label_idx}.png")

        # White grayscale canvas. NumPy shapes are (rows, cols), i.e.
        # (height, width), whereas the OpenCV points below are (x, y).
        img = np.full((height, width), 255, dtype=np.uint8)

        if label_idx == 0:  # circle
            cv2.circle(img, (width // 2, height // 2), width // 4, 0, thickness)
        elif label_idx == 1:  # single vertical stroke
            cv2.line(
                img,
                (width // 2, height // 4),
                (width // 2, height // 4 * 3),
                0,
                thickness,
            )
        elif label_idx == 2:  # three connected strokes
            left, right = width // 4, width // 4 * 3
            mid, low = height // 2, height // 4 * 3
            cv2.line(img, (left, mid), (right, mid), 0, thickness)
            cv2.line(img, (right, mid), (right, low), 0, thickness)
            cv2.line(img, (right, low), (left, low), 0, thickness)
        elif label_idx == 3:  # three horizontal strokes
            left, right = width // 4, width // 4 * 3
            for y in (height // 3, height // 2, height // 3 * 2):
                cv2.line(img, (left, y), (right, y), 0, thickness)
        else:
            # OpenCV's Hershey fonts cannot draw CJK characters (they come
            # out as question marks), which would make all classes above 9
            # look identical. Render the numeric label instead so that every
            # class gets a distinct image.
            text = str(label_idx)
            scale = 0.8
            origin_x = width // 4
            (text_w, _), _ = cv2.getTextSize(text, font, scale, thickness)
            if origin_x + text_w > width:
                # Multi-digit labels overflow at this scale: shrink them to
                # roughly the canvas width and start at the left edge.
                scale *= (width - thickness) / text_w
                origin_x = 0
            cv2.putText(
                img, text, (origin_x, height // 4 * 3), font, scale, 0, thickness
            )

        if not cv2.imwrite(img_path, img):
            raise OSError(f"cv2.imwrite failed for {img_path}")

    # Report the actual arguments rather than module-level constants.
    print(f"✅ 已自动生成{class_count}张测试图片,保存到{data_root}子文件夹中")

# Populate the class folders with the generated images.
generate_test_images(
    data_root=DATA_ROOT,
    class_count=NUM_CLASSES,
    img_size=IMAGE_SIZE,
)

# ---------------------- 4. Data loading and preprocessing ----------------------

def load_data(data_root, img_size, num_classes=None):
    """Load the images of every class folder as flattened grayscale vectors.

    Class ``i`` is read from ``<data_root>/<i>``. Only ``.png``, ``.jpg`` and
    ``.jpeg`` files (case-insensitive) are considered, and files OpenCV
    cannot decode are skipped.

    Args:
        data_root: Dataset root directory containing one folder per label.
        img_size: ``(width, height)`` each image is resized to via
            ``cv2.resize`` before flattening.
        num_classes: Number of class folders to read. Defaults to the
            module-level ``NUM_CLASSES``.

    Returns:
        A ``(features, labels)`` tuple of NumPy arrays. ``features`` has
        shape ``(n_images, width * height)`` (also when no image was found)
        and ``labels`` holds the integer class index of each row.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES

    features = []
    labels = []
    for label_idx in range(num_classes):
        class_folder = os.path.join(data_root, str(label_idx))
        # os.listdir returns entries in arbitrary order. Sorting keeps the
        # sample order, and therefore the seeded train/test split, the same
        # across runs and machines.
        for img_name in sorted(os.listdir(class_folder)):
            if not img_name.lower().endswith((".png", ".jpg", ".jpeg")):
                continue
            img_path = os.path.join(class_folder, img_name)
            img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
            if img is None:  # unreadable or corrupt file
                continue
            features.append(cv2.resize(img, img_size).flatten())
            labels.append(label_idx)

    # Force a 2-D feature matrix so that an empty dataset yields shape
    # (0, n_features) instead of (0,).
    n_features = img_size[0] * img_size[1]
    feature_matrix = np.array(features).reshape(len(features), n_features)
    return feature_matrix, np.array(labels, dtype=int)

# Load the dataset
print("开始加载数据...")
X, y = load_data(data_root=DATA_ROOT, img_size=IMAGE_SIZE)
print(f"数据加载完成:共{X.shape[0]}张图像,特征维度{X.shape}")

# Hold out a quarter of the samples as the test set
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    random_state=RANDOM_SEED,
    test_size=0.25,
)

# ---------------------- Key change: PCA removed ----------------------

print("\n构建模型...")
# Two-stage pipeline: standardise the raw pixel features, then classify them
# with an RBF-kernel SVM.
model = Pipeline(
    steps=[
        ("scaler", StandardScaler()),
        (
            "svm",
            SVC(
                C=1.0,
                kernel="rbf",
                gamma="scale",
                probability=True,
                random_state=RANDOM_SEED,
            ),
        ),
    ]
)

print("开始训练模型...")
# Wall-clock timestamp taken right before fitting; used for the report below.
start_time = time.time()
model.fit(X_train, y_train)
print(f"训练完成!耗时:{time.time() - start_time:.2f}秒")

# ---------------------- 6. Model evaluation ----------------------

print("\n=== 模型评估结果 ===")
y_pred = model.predict(X_test)
print(f"测试集准确率:{accuracy_score(y_test, y_pred):.4f}")
print("\n分类报告:")
# Pass the full label set explicitly. Without `labels`, scikit-learn infers
# the labels from y_test/y_pred and raises ValueError when their count does
# not match len(target_names) -- which happens whenever the test split does
# not contain every class.
print(
    classification_report(
        y_test,
        y_pred,
        labels=list(range(NUM_CLASSES)),
        target_names=CLASS_NAMES,
        zero_division=0,
    )
)

# Confusion matrix visualisation

def plot_confusion_matrix(y_true, y_pred, class_names):
    """Show the confusion matrix of ``y_pred`` against ``y_true`` as a heat map.

    Args:
        y_true: True integer class labels.
        y_pred: Predicted integer class labels.
        class_names: Tick label per class; position ``i`` is used for
            integer label ``i``.
    """
    # Fix the label set to 0..len(class_names)-1 so the matrix is always
    # len(class_names) x len(class_names) and its rows/columns line up with
    # the tick labels, even when some classes are absent from y_true/y_pred.
    label_ids = np.arange(len(class_names))
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)

    # The title, axis labels and tick labels are Chinese, which matplotlib's
    # default font does not cover. Prefer common CJK-capable fonts for this
    # figure only; which of them (if any) is installed depends on the
    # machine, so the previously configured fonts stay in the list as
    # fallback.
    cjk_fonts = [
        "SimHei",
        "Microsoft YaHei",
        "PingFang SC",
        "Noto Sans CJK SC",
        "WenQuanYi Micro Hei",
        "Arial Unicode MS",
    ]
    font_overrides = {
        "font.sans-serif": cjk_fonts + list(plt.rcParams["font.sans-serif"]),
        # Many CJK fonts lack U+2212; use the ASCII hyphen for minus signs.
        "axes.unicode_minus": False,
    }
    # rc_context restores the global rcParams when the block exits.
    with plt.rc_context(font_overrides):
        plt.figure(figsize=(10, 8))
        plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
        plt.title("混淆矩阵", fontsize=14)
        plt.colorbar()
        plt.xticks(label_ids, class_names, rotation=45)
        plt.yticks(label_ids, class_names)
        plt.xlabel("预测标签", fontsize=12)
        plt.ylabel("真实标签", fontsize=12)
        plt.tight_layout()
        plt.show()

# Display the confusion matrix for the test-set predictions.
plot_confusion_matrix(
    y_true=y_test,
    y_pred=y_pred,
    class_names=CLASS_NAMES,
)

# ---------------------- 7. Save the model ----------------------

# Persist the fitted pipeline (scaler + SVM) to disk with joblib.
joblib.dump(model, "chinese_digit_svm_model.pkl")
# The original statement was cut off before its closing quote and
# parenthesis, which made the whole file a syntax error.
print("\n模型已保存为:chinese_digit_svm_model.pkl")

# posted @ 2025-11-13 23:50  千树(好困版)  阅读(9)  评论(0)    收藏  举报