wavefinder

奇妙冒险
最近在使用手写数字识别模型,发现一个问题,当我在pc上训练后,识别精度是很好的,但是放在了嵌入式板卡上运行,发现了新的问题。
找到原因,发现其实是因为图像处理的问题,当手写数字经过opencv处理后,效果不理想,输入模型也无法识别,经过调整后,可以识别0~10的数字,效果还可以。因此记录一下
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model

def enhanced_preprocess(img_path):
    """
    增强版图像预处理(兼容PC端和嵌入式设备)
    返回:
        - 用于PC端模型的float32格式图像(0-1范围)
        - 用于嵌入式设备的uint8格式图像(0-255范围)
    """
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    
    # 1. 改进的二值化(保护细线条)
    blur = cv2.GaussianBlur(img, (3,3), 0)
    binary = cv2.adaptiveThreshold(
        blur, 255, 
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV, 
        blockSize=11,  # 增大块大小
        C=2            # 降低敏感度
    )
    
    # 2. 智能形态学处理
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2,2))
    processed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
    
    # 3. 动态标准化(关键修改)
    contours, _ = cv2.findContours(processed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if contours:
        contour = max(contours, key=cv2.contourArea)
        x,y,w,h = cv2.boundingRect(contour)
        digit = processed[y:y+h, x:x+w]
        
        # 动态膨胀核(根据凸包缺陷调整)
        hull = cv2.convexHull(contour, returnPoints=False)
        defects = cv2.convexityDefects(contour, hull) if hull is not None else None
        
        base_size = min(5, max(2, int(min(h,w)/4)))  # 基础核大小
        if defects is not None and len(defects) > 1:  # 存在开口
            dilation_size = max(1, base_size-2)       # 减小膨胀
        else:
            dilation_size = base_size
        
        digit = cv2.dilate(digit, np.ones((12, 12), np.uint8))
        
        # 更保守的边距设置
        margin = int(max(h,w)*0.40)
        canvas = np.zeros((h+2*margin, w+2*margin), dtype=np.uint8)
        canvas[margin:margin+h, margin:margin+w] = digit
        
        # 改用面积插值(保护开口)
        resized = cv2.resize(canvas, (28,28), interpolation=cv2.INTER_AREA)
    else:
        resized = np.zeros((28,28), dtype=np.uint8)
    
    return resized.astype(np.float32) / 255.0

def predict_with_enhanced_preprocess(image_path, model_path='mnist_model.h5'):
    """集成增强预处理的预测函数"""
    # 1. 加载模型
    model = load_model(model_path)
    
    # 2. 使用增强预处理
    float_img  = enhanced_preprocess(image_path)
    


    plt.imshow(float_img, cmap='gray')
    # 4. 根据模型输入形状调整维度
    if len(model.input_shape) == 3:  # 如 (None,28,28)
        input_data = float_img  # 直接使用(28,28)
    elif len(model.input_shape) == 4:  # 如 (None,28,28,1)
        input_data = np.expand_dims(float_img, axis=-1)  # (28,28,1)
    else:
        raise ValueError(f"不支持的模型输入形状: {model.input_shape}")


    # 5. 添加批次维度
    input_data = np.expand_dims(input_data, axis=0)  # (1,28,28) 或 (1,28,28,1)
    
    # 6. 执行预测
    try:
        prediction = model.predict(input_data)
    except Exception as e:
        print("\n=== 调试信息 ===")
        print(f"模型输入形状: {model.input_shape}")
        print(f"实际输入形状: {input_data.shape}")
        print(f"错误详情: {str(e)}")
        raise
    
    # 7. 可视化结果

    
    return prediction

# 使用示例
if __name__ == "__main__":
    # PC端模型验证
    prediction = predict_with_enhanced_preprocess("0.jpg")

    #plt.imshow(img, cmap='gray')
    plt.title(f"pridict num: {np.argmax(prediction)} (believe: {np.max(prediction):.2%})")
    plt.axis('off')
    plt.show()
    
    # 嵌入式设备使用(获取uint8格式)
    # _, embedded_input = enhanced_preprocess("example.jpg")
    # print("\n嵌入式设备输入示例:")
    # print(f"数据类型: {embedded_input.dtype}")
    # print(f"形状: {embedded_input.shape}")
    # print("左上角5x5像素值:")
    # print(embedded_input[:5, :5])