import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
import random
# ===================== 1. Global configuration (epochs set to 20) =====================
class Config:
    root_dir = "D:/Pysch2/Pytorch/MSRA-TD500"  # dataset root directory
    batch_size = 8
    epochs = 20  # number of training epochs
    lr = 1e-4
    num_workers = 0  # 0 is recommended on Windows
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    char2idx = None
    idx2char = None
    num_classes = None
    # synthetic text pool (extend as needed)
    text_pool = [
        "一二三四", "五六七八", "九十百千", "甲乙丙丁", "金木水火",
        "天地人和", "上下左右", "前后内外", "春夏秋冬", "东南西北",
        "ABCDE", "FGHIJ", "12345", "67890", "XYZUV",
        "测试文本", "识别训练", "流程验证", "数据合成", "模型测试"
    ]
# ===================== 2. Dataset loading (auto-fills synthetic text labels) =====================
class MSRA_TD500_Synth_Dataset(Dataset):
    def __init__(self, root_dir, is_train=True, transform=None):
        self.root_dir = root_dir
        self.is_train = is_train
        self.transform = transform
        self.data_dir = os.path.join(root_dir, "train" if is_train else "test")
        self.image_names = [f for f in os.listdir(self.data_dir) if f.endswith(('.jpg', '.JPG'))]
        random.seed(42)  # fix the random seed so the synthetic text labels are reproducible

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        # load the image
        img_name = self.image_names[idx]
        img_path = os.path.join(self.data_dir, img_name)
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Image file not found: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w = image.shape[:2]
        # build the annotation file path
        img_base_name = os.path.splitext(img_name)[0]
        label_path = os.path.join(self.data_dir, f"{img_base_name}.gt")
        if not os.path.exists(label_path):
            raise FileNotFoundError(f"Annotation file not found: {label_path}")
        # parse the .gt file and attach synthetic text labels
        boxes, texts = self.parse_gt_with_synth_text(label_path)
        # crop the first text region
        if len(boxes) > 0:
            box = boxes[0]
            x_min = min(box[::2])
            x_max = max(box[::2])
            y_min = min(box[1::2])
            y_max = max(box[1::2])
            x_min, x_max = max(0, x_min), min(w, x_max)
            y_min, y_max = max(0, y_min), min(h, y_max)
            crop_img = image[y_min:y_max, x_min:x_max]
            # fall back to the full image if the clipped box is degenerate (empty crop)
            if crop_img.size == 0:
                crop_img = image
        else:
            crop_img = image
        if self.transform:
            crop_img = self.transform(crop_img)
        return crop_img, texts[0] if texts else ""

    @staticmethod
    def parse_gt_with_synth_text(gt_path):
        """Parse a coordinates-only .gt file and attach a synthetic text label to each box."""
        boxes = []
        texts = []
        with open(gt_path, 'r', encoding='gbk') as f:
            lines = f.readlines()
        for line in lines:
            line = line.strip()
            if not line:
                continue
            parts = line.split(',')
            if len(parts) == 8:
                try:
                    coords = list(map(int, parts))
                    boxes.append(coords)
                    texts.append(random.choice(Config.text_pool))
                except ValueError:
                    continue
        return boxes, texts
# ===================== 3. Character map generation (based on the synthetic text pool) =====================
def generate_synth_char_map():
    char_set = set()
    for text in Config.text_pool:
        char_set.update(list(text))
    char_list = sorted(list(char_set))
    Config.char2idx = {char: idx + 1 for idx, char in enumerate(char_list)}
    Config.char2idx['<blank>'] = 0
    Config.idx2char = {v: k for k, v in Config.char2idx.items()}
    Config.num_classes = len(Config.char2idx)
    print(f"Synthetic charset size: {Config.num_classes} (including the blank token)")
    print(f"Charset contents: {char_list}")
# ===================== 4. Data preprocessing =====================
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((32, 128)),
    transforms.RandomAffine(degrees=5, translate=(0.05, 0.05)),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 0.5))], p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((32, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# ===================== 5. CRNN model definition (fixes the LSTM tuple-output error) =====================
class CRNN(nn.Module):
    def __init__(self, num_classes):
        super(CRNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((2, 1), (2, 1)),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((2, 1), (2, 1)),
            nn.Conv2d(512, 512, kernel_size=2, padding=0),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        # define the LSTM layers separately instead of wrapping them in Sequential,
        # because nn.LSTM returns an (output, (h_n, c_n)) tuple
        self.lstm1 = nn.LSTM(512, 256, bidirectional=True, batch_first=True)
        self.lstm2 = nn.LSTM(512, 256, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        conv_out = self.conv(x)               # (N, 512, 1, W') for 32x128 inputs
        conv_out = conv_out.squeeze(2)        # (N, 512, W')
        conv_out = conv_out.permute(0, 2, 1)  # (N, W', 512): treat width as the time axis
        # keep only the sequence output of each LSTM (discard the hidden-state tuple)
        rnn_out, _ = self.lstm1(conv_out)
        rnn_out, _ = self.lstm2(rnn_out)
        logits = self.fc(rnn_out)             # (N, W', num_classes)
        return logits
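# Shape walk-through (illustrative, assuming 32x128 RGB crops from the transforms above):
# input (N, 3, 32, 128) -> conv stack -> (N, 512, 1, 31) -> squeeze/permute -> (N, 31, 512)
# -> two BiLSTMs -> (N, 31, 512) -> fc -> (N, 31, num_classes), i.e. 31 time steps for CTC.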
# ===================== 6. Training and evaluation utilities =====================
def encode_text(texts):
    """Encode a batch of strings into a flat CTC target tensor plus per-sample lengths."""
    labels = []
    lengths = []
    for text in texts:
        label = [Config.char2idx[c] for c in text if c in Config.char2idx]
        labels.extend(label)
        lengths.append(len(label))
    return torch.tensor(labels, dtype=torch.long), torch.tensor(lengths, dtype=torch.long)


def ctc_decode(log_probs):
    """Greedy CTC decoding: collapse repeated characters and drop blanks (index 0)."""
    outputs = []
    for prob in log_probs:
        pred = torch.argmax(prob, dim=1).cpu().numpy()
        text = []
        prev_char = None
        for c in pred:
            if c != 0 and c != prev_char:
                text.append(Config.idx2char[c])
            prev_char = c
        outputs.append(''.join(text))
    return outputs
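# Example (illustrative; actual indices depend on the generated charset): a greedy argmax
# path such as [0, 3, 3, 0, 7, 7, 0] collapses consecutive repeats and drops the blank
# index 0, yielding idx2char[3] + idx2char[7] as the decoded string.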
# ===================== 7. Main training pipeline =====================
def main():
    generate_synth_char_map()
    if Config.num_classes < 2:
        print("Error: the charset is empty, please extend Config.text_pool!")
        return
    # build the datasets
    try:
        train_dataset = MSRA_TD500_Synth_Dataset(
            root_dir=Config.root_dir, is_train=True, transform=train_transform
        )
        test_dataset = MSRA_TD500_Synth_Dataset(
            root_dir=Config.root_dir, is_train=False, transform=test_transform
        )
    except Exception as e:
        print(f"Dataset initialization failed: {e}")
        return
    # data loaders
    train_loader = DataLoader(
        train_dataset, batch_size=Config.batch_size, shuffle=True,
        num_workers=Config.num_workers, pin_memory=True, drop_last=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=Config.batch_size, shuffle=False,
        num_workers=Config.num_workers, pin_memory=True
    )
    # model, loss function, optimizer
    model = CRNN(num_classes=Config.num_classes).to(Config.device)
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=Config.lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.epochs)
    # training loop
    print(f"Starting training (device: {Config.device})")
    print(f"Training samples: {len(train_dataset)}, test samples: {len(test_dataset)}")
    for epoch in range(Config.epochs):
        model.train()
        total_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{Config.epochs}]")
        for images, texts in pbar:
            images = images.to(Config.device)
            labels, label_lengths = encode_text(texts)
            labels = labels.to(Config.device)
            label_lengths = label_lengths.to(Config.device)
            # forward pass
            logits = model(images)  # (N, T, C)
            log_probs = F.log_softmax(logits, dim=2)
            input_lengths = torch.full(
                (logits.size(0),), logits.size(1), dtype=torch.long
            ).to(Config.device)
            # compute the loss (skip batches with no valid targets)
            if label_lengths.sum() == 0:
                continue
            # nn.CTCLoss expects (T, N, C) log-probabilities, so move the time axis first
            loss = criterion(log_probs.permute(1, 0, 2), labels, input_lengths, label_lengths)
            # backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            total_loss += loss.item() * images.size(0)
            pbar.set_postfix({"Loss": loss.item()})
        # learning-rate schedule
        scheduler.step()
        # report training progress
        avg_loss = total_loss / len(train_loader.dataset)
        print(f"Epoch [{epoch+1}/{Config.epochs}], Average Loss: {avg_loss:.4f}")
        # evaluate on the test set every 5 epochs
        if (epoch + 1) % 5 == 0:
            model.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for images, texts in test_loader:
                    images = images.to(Config.device)
                    logits = model(images)
                    log_probs = F.log_softmax(logits, dim=2)
                    preds = ctc_decode(log_probs)
                    # print a few sample predictions
                    for i in range(min(3, len(preds))):
                        print(f"Ground truth: {texts[i]}, prediction: {preds[i]}")
                    # accumulate exact-match accuracy
                    for pred, text in zip(preds, texts):
                        if pred == text:
                            correct += 1
                        total += 1
            if total > 0:
                acc = correct / total
                print(f"Test Accuracy: {acc:.4f}")
            else:
                print("No valid samples in the test set")
    # save the model
    torch.save(model.state_dict(), "crnn_msra_td500_synth_epoch20.pth")
    print("Training finished; the 20-epoch synthetic-text model has been saved!")
if __name__ == "__main__":
    main()