# 文字识别系统

import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from torchvision.transforms.functional import to_tensor, rgb_to_grayscale, resize

# ---------------------- 1. 完善字符集(比如加“测”“试”“字”) ----------------------
# Recognisable characters: digits, lower/upper-case ASCII letters, then the
# demo CJK characters. A character's class index is its position in this
# string, so append new characters at the end to keep existing indices stable.
CHARS = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ测试字'
# Reverse lookup: character -> class index (position in CHARS).
char_to_idx = dict(zip(CHARS, range(len(CHARS))))
# Number of output classes the model must produce.
num_chars = len(CHARS)

# ---------------------- 2. 更鲁棒的模型 ----------------------
class BetterOCR(nn.Module):
    """Small CNN that classifies a single-channel image into one character class.

    Architecture: three Conv(3x3, pad 1) -> ReLU -> MaxPool(2) stages
    (1 -> 32 -> 64 -> 128 channels), then a two-layer MLP head.

    The head expects a 128x6x6 feature map, which is exactly what a 48x48
    input yields after three 2x poolings. A parameter-free adaptive pool
    resamples the feature map to 6x6 so other input sizes also work. For 48x48
    inputs it is an exact identity, and it adds nothing to the state_dict.

    Args:
        num_classes: number of output logits. ``None`` (the default) uses the
            module-level ``num_chars``, so ``BetterOCR()`` behaves as before.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = num_chars
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 48x48 input -> (N, 32, 24, 24)
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> (N, 64, 12, 12)
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> (N, 128, 6, 6)
        )
        # No parameters: checkpoint keys/shapes and weight-init RNG use are
        # unchanged. Guarantees the 6x6 map the Linear layer below requires.
        self.pool = nn.AdaptiveAvgPool2d((6, 6))
        self.fc = nn.Sequential(
            nn.Linear(128 * 6 * 6, 256),  # flattened features: 128*6*6 = 4608
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return raw class logits of shape (N, num_classes).

        ``x`` must have a single channel (the first conv has 1 input channel);
        presumably a (N, 1, H, W) float batch such as ``process_image`` builds.
        """
        x = self.conv(x)
        x = self.pool(x)
        x = x.flatten(1)  # keep the batch dim, flatten C*H*W
        return self.fc(x)

# ---------------------- 3. 专业图片处理(适配真实图片) ----------------------
def process_image(image_path, size=(48, 48)):
    """Load an image file as a grayscale, fixed-size tensor batch for the model.

    Args:
        image_path: path of the image file to open with Pillow.
        size: target (height, width) passed to torchvision ``resize``.
            Defaults to 48x48, the input size ``BetterOCR`` was designed for.

    Returns:
        A tensor of shape ``[1, 1, H, W]``. It has a batch dim of 1 and a single
        grayscale channel, produced by ``to_tensor`` plus ``unsqueeze(0)``.
    """
    from PIL import Image  # local import: only this helper needs Pillow

    # Open inside ``with`` so the file handle is released deterministically.
    # ``convert`` fully loads the pixels into a new image, so ``img`` stays
    # usable after the source file is closed.
    with Image.open(image_path) as src:
        img = src.convert('L')  # grayscale, single channel
    img = resize(img, size)  # unify spatial size
    return to_tensor(img).unsqueeze(0)  # [1, 1, H, W]

# ---------------------- 4. 极简训练(让模型“学会”识别) ----------------------
def train_model(model, train_data, epochs=10):
    """Train ``model`` in place with Adam (lr=0.001) and cross-entropy loss.

    Args:
        model: module mapping an input to class logits.
        train_data: iterable of ``(input, target)`` pairs. Each pair is fed
            directly to the model and loss (no batching here). It is iterated
            once per epoch, so it must be re-iterable (e.g. a list).
        epochs: number of passes over ``train_data``.

    Prints the summed (not averaged) loss after every epoch.

    Returns:
        The same ``model`` object, left in training mode.
    """
    opt = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    for epoch_idx in range(epochs):
        running = 0.0
        for batch_img, batch_label in train_data:
            opt.zero_grad()
            step_loss = loss_fn(model(batch_img), batch_label)
            step_loss.backward()
            opt.step()
            running += step_loss.item()
        print(f"Epoch {epoch_idx+1}, Loss: {running:.4f}")
    return model

# ---------------------- 5. 生成训练数据(模拟“打一个字”的样本) ----------------------
def make_train_data(samples=None):
    """Build the tiny demo training set: one (image, label) pair per sample.

    Args:
        samples: optional iterable of ``(image_path, char)`` pairs. ``None``
            (the default) uses the three demo images for "测", "试" and "字"
            in that order. Each ``char`` must be a key of ``char_to_idx``.

    Returns:
        List of ``(img_tensor, label_tensor)`` tuples in sample order.
        ``img_tensor`` comes from ``process_image``, and ``label_tensor`` is a
        one-element tensor holding the class index of ``char``.

    Raises:
        KeyError: if a ``char`` is not in ``char_to_idx``. Errors from
            ``process_image`` (e.g. a missing file) propagate unchanged.
    """
    if samples is None:
        # Replace these with real images that contain the matching character.
        samples = (
            ("D:\\test1.jpg", '测'),
            ("D:\\test2.jpg", '试'),
            ("D:\\test3.jpg", '字'),
        )
    train_data = []
    for image_path, char in samples:
        img = process_image(image_path)
        label = torch.tensor([char_to_idx[char]])
        train_data.append((img, label))
    return train_data

# ---------------------- 运行:训练+识别 ----------------------
if __name__ == "__main__":
    # Build the model and train it on the demo samples.
    model = BetterOCR()
    print("生成训练数据...")
    train_data = make_train_data()
    print("训练模型...")
    model = train_model(model, train_data)

    # Recognise a target image (here the one containing "字").
    IMAGE_PATH = "D://test3.jpg"
    if not os.path.exists(IMAGE_PATH):
        # Report which path is missing so the user can fix it.
        print(f"图片不存在!路径: {IMAGE_PATH}")
    else:
        img = process_image(IMAGE_PATH)
        model.eval()
        with torch.no_grad():
            output = model(img)
            prob = F.softmax(output, dim=1)
            pred_idx = prob.argmax(1).item()
            # .item() turns the 0-dim tensor into a plain Python float.
            confidence = prob[0, pred_idx].item() * 100
        result = CHARS[pred_idx]
        print(f"识别结果:{result} | 匹配度:{confidence:.2f}%")

# [图片] 微信图片_20251120234310_216_36

# posted @ 2025-11-20 23:56  叶柯鑫  阅读(0)  评论(0)    收藏  举报