最近在使用手写数字识别模型,发现一个问题,当我在pc上训练后,识别精度是很好的,但是放在了嵌入式板卡上运行,发现了新的问题。
找到原因,发现其实是因为图像处理的问题,当手写数字经过opencv处理后,效果不理想,输入模型也无法识别,经过调整后,可以识别0~10的数字,效果还可以。因此记录一下
找到原因,发现其实是因为图像处理的问题,当手写数字经过opencv处理后,效果不理想,输入模型也无法识别,经过调整后,可以识别0~10的数字,效果还可以。因此记录一下
import cv2 import numpy as np import matplotlib.pyplot as plt from tensorflow.keras.models import load_model def enhanced_preprocess(img_path): """ 增强版图像预处理(兼容PC端和嵌入式设备) 返回: - 用于PC端模型的float32格式图像(0-1范围) - 用于嵌入式设备的uint8格式图像(0-255范围) """ img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) # 1. 改进的二值化(保护细线条) blur = cv2.GaussianBlur(img, (3,3), 0) binary = cv2.adaptiveThreshold( blur, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, blockSize=11, # 增大块大小 C=2 # 降低敏感度 ) # 2. 智能形态学处理 kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2,2)) processed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel) # 3. 动态标准化(关键修改) contours, _ = cv2.findContours(processed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) if contours: contour = max(contours, key=cv2.contourArea) x,y,w,h = cv2.boundingRect(contour) digit = processed[y:y+h, x:x+w] # 动态膨胀核(根据凸包缺陷调整) hull = cv2.convexHull(contour, returnPoints=False) defects = cv2.convexityDefects(contour, hull) if hull is not None else None base_size = min(5, max(2, int(min(h,w)/4))) # 基础核大小 if defects is not None and len(defects) > 1: # 存在开口 dilation_size = max(1, base_size-2) # 减小膨胀 else: dilation_size = base_size digit = cv2.dilate(digit, np.ones((12, 12), np.uint8)) # 更保守的边距设置 margin = int(max(h,w)*0.40) canvas = np.zeros((h+2*margin, w+2*margin), dtype=np.uint8) canvas[margin:margin+h, margin:margin+w] = digit # 改用面积插值(保护开口) resized = cv2.resize(canvas, (28,28), interpolation=cv2.INTER_AREA) else: resized = np.zeros((28,28), dtype=np.uint8) return resized.astype(np.float32) / 255.0 def predict_with_enhanced_preprocess(image_path, model_path='mnist_model.h5'): """集成增强预处理的预测函数""" # 1. 加载模型 model = load_model(model_path) # 2. 使用增强预处理 float_img = enhanced_preprocess(image_path) plt.imshow(float_img, cmap='gray') # 4. 根据模型输入形状调整维度 if len(model.input_shape) == 3: # 如 (None,28,28) input_data = float_img # 直接使用(28,28) elif len(model.input_shape) == 4: # 如 (None,28,28,1) input_data = np.expand_dims(float_img, axis=-1) # (28,28,1) else: raise ValueError(f"不支持的模型输入形状: {model.input_shape}") # 5. 添加批次维度 input_data = np.expand_dims(input_data, axis=0) # (1,28,28) 或 (1,28,28,1) # 6. 执行预测 try: prediction = model.predict(input_data) except Exception as e: print("\n=== 调试信息 ===") print(f"模型输入形状: {model.input_shape}") print(f"实际输入形状: {input_data.shape}") print(f"错误详情: {str(e)}") raise # 7. 可视化结果 return prediction # 使用示例 if __name__ == "__main__": # PC端模型验证 prediction = predict_with_enhanced_preprocess("0.jpg") #plt.imshow(img, cmap='gray') plt.title(f"pridict num: {np.argmax(prediction)} (believe: {np.max(prediction):.2%})") plt.axis('off') plt.show() # 嵌入式设备使用(获取uint8格式) # _, embedded_input = enhanced_preprocess("example.jpg") # print("\n嵌入式设备输入示例:") # print(f"数据类型: {embedded_input.dtype}") # print(f"形状: {embedded_input.shape}") # print("左上角5x5像素值:") # print(embedded_input[:5, :5])