# 文字识别
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from torchvision.transforms.functional import to_tensor, rgb_to_grayscale, resize
# ---------------------- 1. 完善字符集(比如加“测”“试”“字”) ----------------------
# Character vocabulary: digits, lowercase and uppercase ASCII letters, then
# the three Chinese sample characters. A character's position in CHARS is
# its class index.
CHARS = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ测试字'
num_chars = len(CHARS)
char_to_idx = dict(zip(CHARS, range(num_chars)))
# ---------------------- 2. 更鲁棒的模型 ----------------------
class BetterOCR(nn.Module):
    """Small CNN classifier for single-character 48x48 grayscale images.

    Three conv/ReLU/max-pool stages halve the spatial size each time
    (48 -> 24 -> 12 -> 6); two linear layers then map the flattened
    128x6x6 feature map to one logit per character class.

    Args:
        num_classes: Number of output logits. When omitted, the
            module-level ``num_chars`` is used.
    """

    def __init__(self, num_classes=None):
        # Must be the real dunder names: nn.Module only registers the
        # sub-modules below if its own __init__ has run first.
        super().__init__()
        if num_classes is None:
            num_classes = num_chars
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # (N, 32, 24, 24) for a 48x48 input
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # (N, 64, 12, 12)
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # (N, 128, 6, 6)
        )
        self.fc = nn.Sequential(
            # Flattened feature size: 128 * 6 * 6 = 4608.
            nn.Linear(128 * 6 * 6, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape (N, num_classes) for input (N, 1, 48, 48)."""
        x = self.conv(x)
        x = x.flatten(1)  # keep the batch dimension, flatten the rest
        x = self.fc(x)
        return x
# ---------------------- 3. 专业图片处理(适配真实图片) ----------------------
def process_image(image_path, size=(48, 48)):
    """Load an image file as a grayscale tensor ready for the model.

    Args:
        image_path: Path of the image file to read.
        size: Target (height, width) passed to ``resize``; defaults to
            the 48x48 input the model is built for.

    Returns:
        The ``to_tensor`` result with a leading batch dimension added,
        i.e. shape [1, 1, 48, 48] for the default size.
    """
    # Imported locally so PIL is only required when an image is actually read.
    from PIL import Image

    # The context manager releases the file handle deterministically;
    # convert('L') yields a separate grayscale image that outlives it.
    with Image.open(image_path) as raw:
        gray = raw.convert('L')
    gray = resize(gray, size)  # uniform size
    img_tensor = to_tensor(gray).unsqueeze(0)  # [1, 1, H, W]
    return img_tensor
# ---------------------- 4. 极简训练(让模型“学会”识别) ----------------------
def train_model(model, train_data, epochs=10, lr=0.001):
    """Train ``model`` in place with Adam and cross-entropy loss.

    Args:
        model: Module to optimize; called as ``model(img)``.
        train_data: Iterable of ``(img, label)`` pairs. It is iterated
            once per epoch, so it must be re-iterable (e.g. a list).
        epochs: Number of passes over ``train_data``.
        lr: Learning rate handed to the Adam optimizer.

    Returns:
        The same ``model`` object, left in training mode.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for img, label in train_data:
            optimizer.zero_grad()
            output = model(img)
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # Reported value is the sum of per-sample losses for this epoch.
        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")
    return model
# ---------------------- 5. 生成训练数据(模拟“打一个字”的样本) ----------------------
def make_train_data():
    """Build the three-sample training set for the characters 测, 试 and 字.

    Returns:
        A list of ``(image_tensor, label_tensor)`` pairs, one per sample,
        where the image comes from ``process_image`` and the label is a
        1-element tensor holding the character's index in ``char_to_idx``.
    """
    # (image path, character shown in that image). Replace the paths with
    # real images of the corresponding characters.
    # Backslashes are escaped: an unescaped "\t" would be read as a TAB.
    samples = (
        ("D:\\test1.jpg", '测'),
        ("D:\\test2.jpg", '试'),
        ("D:\\test3.jpg", '字'),
    )
    return [
        (process_image(path), torch.tensor([char_to_idx[char]]))
        for path, char in samples
    ]
# ---------------------- 运行:训练+识别 ----------------------
# Script entry point: train on the sample data, then classify one image.
if __name__ == "__main__":
    # Initialise and train the model.
    model = BetterOCR()
    print("生成训练数据...")
    train_data = make_train_data()
    print("训练模型...")
    model = train_model(model, train_data)

    # Recognise the target image (e.g. the one containing "字").
    IMAGE_PATH = "D://test3.jpg"
    if not os.path.exists(IMAGE_PATH):
        print("图片不存在!")
    else:
        img = process_image(IMAGE_PATH)
        model.eval()  # switch the model to evaluation mode for inference
        with torch.no_grad():  # no gradients needed for prediction
            output = model(img)
            prob = F.softmax(output, dim=1)
        pred_idx = prob.argmax(1).item()
        confidence = prob[0][pred_idx] * 100  # probability as a percentage
        result = CHARS[pred_idx]
        print(f"识别结果:{result} | 匹配度:{confidence:.2f}%")

# 浙公网安备 33010602011771号