# Handwritten Chinese character recognition
# Final version without PCA: fixes the dimension error
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.pipeline import Pipeline
import time
import joblib
# ---------------------- 1. Core configuration ----------------------
# Root folder of the dataset; one sub-folder per integer class label lives below it.
DATA_ROOT = "chinese_digit_dataset"
# Size every image is resized to before it is flattened into a feature vector.
IMAGE_SIZE = (32, 32)
# Display name of each class; the list index is the integer class label.
CLASS_NAMES = ["零", "一", "二", "三", "四", "五", "六", "七", "八", "九", "十", "佰", "仟", "万", "亿"]
NUM_CLASSES = len(CLASS_NAMES)
# Seed shared by the train/test split and the SVM so runs are repeatable.
RANDOM_SEED = 42
# ---------------------- 2. Create the dataset folder structure ----------------------
# Remember whether the root was already there so the matching status message
# can be printed once the folders have been ensured.
_root_existed = os.path.exists(DATA_ROOT)
# exist_ok=True also repairs a root folder that exists but is missing some of
# its per-class sub-folders, which would otherwise break the directory
# listing performed when the data is loaded.
for label in range(NUM_CLASSES):
    os.makedirs(os.path.join(DATA_ROOT, str(label)), exist_ok=True)
if _root_existed:
    print(f"📂 数据集文件夹已存在:{DATA_ROOT}")
else:
    print(f"✅ 已自动创建数据集文件夹:{DATA_ROOT}(包含0-{NUM_CLASSES - 1}子文件夹)")
# ---------------------- 3. Generate test images automatically ----------------------
def generate_test_images(data_root, class_count, img_size):
    """Write one synthetic placeholder image per class label.

    For every label in ``range(class_count)`` a white grayscale image is
    created, a simple black glyph is drawn on it, and the result is saved as
    ``<data_root>/<label>/<label>.png``.

    Args:
        data_root: Dataset root folder; the per-class sub-folders are created
            if they do not exist yet.
        class_count: Number of class labels to generate an image for.
        img_size: ``(width, height)`` of the generated images, i.e. the same
            order in which the size is handed to ``cv2.resize`` when the
            data is loaded.
    """
    width, height = img_size
    # Stroke anchor points shared by the hand-drawn glyphs below.
    left, right = width // 4, width // 4 * 3
    mid_x, mid_y = width // 2, height // 2
    top, bottom = height // 4, height // 4 * 3

    for label_idx in range(class_count):
        class_folder = os.path.join(data_root, str(label_idx))
        # Make sure the target folder exists before cv2.imwrite is asked to
        # write into it.
        os.makedirs(class_folder, exist_ok=True)
        img_path = os.path.join(class_folder, f"{label_idx}.png")

        # White background. NumPy shapes are (rows, cols) == (height, width),
        # while the OpenCV drawing calls below take (x, y) points.
        img = np.full((height, width), 255, dtype=np.uint8)

        if label_idx == 0:  # zero: a circle
            cv2.circle(img, (mid_x, mid_y), width // 4, 0, 2)
        elif label_idx == 1:  # one: a single stroke
            cv2.line(img, (mid_x, top), (mid_x, bottom), 0, 2)
        elif label_idx == 2:  # two: three connected strokes
            cv2.line(img, (left, mid_y), (right, mid_y), 0, 2)
            cv2.line(img, (right, mid_y), (right, bottom), 0, 2)
            cv2.line(img, (right, bottom), (left, bottom), 0, 2)
        elif label_idx == 3:  # three: three parallel strokes
            cv2.line(img, (left, height // 3), (right, height // 3), 0, 2)
            cv2.line(img, (left, mid_y), (right, mid_y), 0, 2)
            cv2.line(img, (left, height // 3 * 2), (right, height // 3 * 2), 0, 2)
        else:
            # The built-in Hershey fonts cannot draw CJK characters (OpenCV
            # replaces unsupported symbols with '?'), which would make the
            # images of all labels above 9 look alike. Draw the numeric
            # label instead so every class gets a distinct image.
            text = str(label_idx)
            if len(text) == 1:
                origin, scale, thickness = (left, bottom), 0.8, 2
            else:
                # Shrink multi-digit labels so they fit inside the image.
                origin, scale, thickness = (width // 8, bottom), 0.5, 1
            cv2.putText(img, text, origin,
                        cv2.FONT_HERSHEY_SIMPLEX, scale, 0, thickness)

        cv2.imwrite(img_path, img)

    # Report the values this call actually used rather than module globals.
    print(f"✅ 已自动生成{class_count}张测试图片,保存到{data_root}子文件夹中")
# Call the generator so that every class folder contains a sample image.
generate_test_images(DATA_ROOT, NUM_CLASSES, IMAGE_SIZE)
# ---------------------- 4. Data loading and preprocessing ----------------------
def load_data(data_root, img_size, class_count=None):
    """Load the images of every class folder as flattened feature vectors.

    Each ``<data_root>/<label>`` folder is scanned for ``.png``/``.jpg``/
    ``.jpeg`` files. Every readable image is loaded as grayscale, resized to
    ``img_size`` and flattened into one row of the feature matrix.

    Args:
        data_root: Dataset root folder containing one sub-folder per label.
        img_size: Target size passed to ``cv2.resize``, i.e. ``(width, height)``.
        class_count: Number of labels to scan (folders ``0`` to
            ``class_count - 1``). Defaults to the module-level ``NUM_CLASSES``.

    Returns:
        A ``(features, labels)`` tuple of NumPy arrays: ``features`` has one
        flattened image per row and ``labels`` holds the matching integer
        labels. When no image is found the arrays are empty, with shapes
        ``(0, width * height)`` and ``(0,)``.
    """
    if class_count is None:
        class_count = NUM_CLASSES

    features = []
    labels = []
    for label_idx in range(class_count):
        class_folder = os.path.join(data_root, str(label_idx))
        if not os.path.isdir(class_folder):
            # A class without a folder simply contributes no samples.
            continue
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # sample order, and therefore the seeded train/test split, stable.
        for img_name in sorted(os.listdir(class_folder)):
            if not img_name.lower().endswith((".png", ".jpg", ".jpeg")):
                continue
            img = cv2.imread(os.path.join(class_folder, img_name),
                             cv2.IMREAD_GRAYSCALE)
            if img is None:
                # cv2.imread signals an unreadable file by returning None.
                continue
            img = cv2.resize(img, img_size)
            features.append(img.flatten())
            labels.append(label_idx)

    if not features:
        # Keep the feature matrix two-dimensional even when it is empty.
        return (
            np.empty((0, img_size[0] * img_size[1]), dtype=np.uint8),
            np.empty((0,), dtype=int),
        )
    return np.array(features), np.array(labels)
# Load every image as a flattened feature vector plus its integer label.
print("开始加载数据...")
X, y = load_data(DATA_ROOT, IMAGE_SIZE)
print(f"数据加载完成:共{X.shape[0]}张图像,特征维度{X.shape}")
# Split into training and test sets: 25% of the samples are held out for
# evaluation, and the fixed seed keeps the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=RANDOM_SEED
)
# ---------------------- Key change: PCA removed ----------------------
# The pipeline standardizes the raw pixel features and feeds them straight
# into an RBF-kernel SVM; there is no dimensionality-reduction step.
print("\n构建模型...")
model = Pipeline([
    ("scaler", StandardScaler()),
    ("svm", SVC(kernel="rbf", C=1.0, gamma="scale", probability=True, random_state=RANDOM_SEED)),
])
print("开始训练模型...")
# perf_counter is a monotonic clock intended for measuring elapsed time.
start_time = time.perf_counter()
model.fit(X_train, y_train)
print(f"训练完成!耗时:{time.perf_counter() - start_time:.2f}秒")
# ---------------------- 6. Model evaluation ----------------------
print("\n=== 模型评估结果 ===")
y_pred = model.predict(X_test)
print(f"测试集准确率:{accuracy_score(y_test, y_pred):.4f}")
print("\n分类报告:")
# labels= pins the report to the full class list. Without it the report only
# covers the labels that occur in y_test/y_pred and raises a ValueError when
# their count differs from len(CLASS_NAMES), which happens whenever the test
# split does not contain every class.
print(classification_report(
    y_test,
    y_pred,
    labels=list(range(NUM_CLASSES)),
    target_names=CLASS_NAMES,
    zero_division=0,
))
# Confusion-matrix visualization
def plot_confusion_matrix(y_true, y_pred, class_names):
    """Display the confusion matrix of a prediction as a heat map.

    Args:
        y_true: True integer class labels.
        y_pred: Predicted integer class labels.
        class_names: Display name of every class; the position of a name in
            this sequence must be the integer label it belongs to.
    """
    tick_positions = np.arange(len(class_names))
    # Pin the label set so the matrix always has one row and one column per
    # entry of class_names. Without labels= the matrix only covers the labels
    # present in y_true/y_pred and the tick names below would be misaligned.
    cm = confusion_matrix(y_true, y_pred, labels=tick_positions)

    # NOTE(review): the title, axis labels and class names are CJK text;
    # matplotlib's default font may not contain these glyphs, so a CJK-capable
    # font probably has to be configured for them to render - confirm.
    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
    plt.title("混淆矩阵", fontsize=14)
    plt.colorbar()
    plt.xticks(tick_positions, class_names, rotation=45)
    plt.yticks(tick_positions, class_names)
    plt.xlabel("预测标签", fontsize=12)
    plt.ylabel("真实标签", fontsize=12)
    plt.tight_layout()
    plt.show()
plot_confusion_matrix(y_test, y_pred, CLASS_NAMES)
# ---------------------- 7. Save the model ----------------------
# Name the output file once so the dump and the status message cannot drift apart.
MODEL_PATH = "chinese_digit_svm_model.pkl"
joblib.dump(model, MODEL_PATH)
print(f"\n模型已保存为:{MODEL_PATH}")

# (web-page footer captured with the paste, not part of the script) 浙公网安备 33010602011771号