# 文字识别 (Text recognition / OCR demo)

import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from torchvision.transforms.functional import to_tensor, rgb_to_grayscale, resize

# ---------------------- 1. Extend the character set (e.g. add "测", "试", "字") ----------------------

# Label alphabet: digits, lowercase and uppercase ASCII letters, then the three
# CJK characters used by the demo. A character's position is its class index.
CHARS = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ测试字'
# Number of output classes the classifier must distinguish.
num_chars = len(CHARS)
# Reverse lookup: character -> class index.
char_to_idx = dict(zip(CHARS, range(num_chars)))

# ---------------------- 2. A more robust model ----------------------

class BetterOCR(nn.Module):
    """Small CNN that classifies a 48x48 single-channel image as one character.

    Three conv / ReLU / max-pool stages halve the spatial size each time
    (48 -> 24 -> 12 -> 6) while widening the channels (1 -> 32 -> 64 -> 128).
    The flattened 128 * 6 * 6 = 4608 features feed a two-layer classifier.

    Args:
        num_classes: Number of output classes. When omitted, the module-level
            ``num_chars`` is used, so ``BetterOCR()`` keeps working unchanged.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        if num_classes is None:
            # Default to the size of the file's character set.
            num_classes = num_chars
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> [N, 32, 24, 24] for a 48x48 input
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> [N, 64, 12, 12]
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> [N, 128, 6, 6]
        )
        self.fc = nn.Sequential(
            # Flattened feature size: 128 channels * 6 * 6 = 4608.
            nn.Linear(128 * 6 * 6, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return raw class logits of shape [N, num_classes].

        Args:
            x: Input batch; the first linear layer fixes the expected layout
                to [N, 1, 48, 48].
        """
        features = self.conv(x)
        # Keep the batch dimension and flatten channels/height/width.
        return self.fc(features.flatten(1))

# ---------------------- 3. Proper image preprocessing (works on real images) ----------------------

def process_image(image_path):
    """Load an image file and turn it into a 48x48 grayscale batch tensor.

    Args:
        image_path: Path of the image file, opened with PIL.

    Returns:
        Tensor of shape [1, 1, 48, 48]: ``to_tensor`` output with a leading
        batch dimension added by ``unsqueeze(0)``.
    """
    # Local import: Pillow is only needed by this function.
    from PIL import Image

    # The context manager closes the file handle once the grayscale copy
    # has been created, instead of leaving it open until garbage collection.
    with Image.open(image_path) as source:
        gray = source.convert('L')  # single-channel grayscale
    gray = resize(gray, (48, 48))  # uniform input size
    return to_tensor(gray).unsqueeze(0)  # [1, 1, 48, 48]

# ---------------------- 4. Minimal training (teach the model to recognise characters) ----------------------

def train_model(model, train_data, epochs=10):
    """Train ``model`` in place with Adam and cross-entropy loss.

    Args:
        model: The ``nn.Module`` to optimise.
        train_data: Iterable of ``(img, label)`` pairs. It is iterated once
            per epoch, so it must be re-iterable (e.g. a list).
        epochs: Number of passes over ``train_data``.

    Returns:
        The same ``model`` object, left in training mode.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for img, label in train_data:
            optimizer.zero_grad()
            output = model(img)
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # Report the summed loss over all samples once per epoch.
        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")
    return model

# ---------------------- 5. Generate training data (simulated single-character samples) ----------------------

def make_train_data():
    """Build the three-sample training set for the characters 测, 试 and 字.

    Each sample pairs the tensor returned by ``process_image`` for one image
    file with a one-element label tensor holding that character's index in
    ``char_to_idx``.

    Returns:
        List of ``(image_tensor, label_tensor)`` tuples, in the order
        测, 试, 字.
    """
    # Backslashes must be escaped: an unescaped "\t" in "D:\test1.jpg" would
    # be read as a TAB character and point at a non-existent path.
    samples = (
        ("D:\\test1.jpg", '测'),  # replace with an image containing 测
        ("D:\\test2.jpg", '试'),  # image containing 试
        ("D:\\test3.jpg", '字'),  # image containing 字
    )
    return [
        (process_image(path), torch.tensor([char_to_idx[char]]))
        for path, char in samples
    ]

# ---------------------- Run: train + recognise ----------------------

if __name__ == "__main__":
    # Build a fresh model and train it on the generated samples.
    model = BetterOCR()
    print("生成训练数据...")
    train_data = make_train_data()
    print("训练模型...")
    model = train_model(model, train_data)

    # Recognise the target image (e.g. the one containing 字).
    IMAGE_PATH = "D://test3.jpg"
    if not os.path.exists(IMAGE_PATH):
        print(f"图片不存在!")
    else:
        img = process_image(IMAGE_PATH)
        model.eval()  # switch to inference mode
        with torch.no_grad():
            output = model(img)
            prob = F.softmax(output, dim=1)
            pred_idx = prob.argmax(1).item()
            # Convert to a plain float so the format spec below applies to
            # a Python number rather than a tensor.
            confidence = prob[0][pred_idx].item() * 100
        result = CHARS[pred_idx]
        print(f"识别结果:{result} | 匹配度:{confidence:.2f}%")

# 1cb1608b39fa8d3a41dd6f17e9b78319

# posted @ 2025-11-21 00:09  kkkk0515  阅读(0)  评论(0)    收藏  举报