import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
from PIL import Image
import random
# ===================== 1. Configuration (minimal settings) =====================
DEVICE = torch.device("cpu")  # force CPU (lowest spec; runs without a GPU)
IMG_SIZE = 32    # resize images to 32x32 (smaller = fewer resources)
BATCH_SIZE = 8   # batch size reduced to 8 (friendly to low-memory machines)
EPOCHS = 5       # train for only 5 epochs (quick verification)
LR = 0.01        # moderate learning rate
# Number of classes: HWDB1.1 commonly covers 3,755 level-1 Chinese characters;
# simplified to 100 classes here for testing (change as needed)
NUM_CLASSES = 100
# Dataset path (replace with your own HWDB1.1 image path); a raw string keeps
# the Windows backslashes from being read as escape sequences
HWDB_ROOT = r"D:\Pysch2\Pytorch\HWDB1.1trn_gnt\HWDB1.1trm_nt"
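# --- Optional helper (not in the original script): HWDB1.1 is distributed as
# .gnt files, while _load_data() below expects per-character image folders.
# This sketch decodes one .gnt file into that folder layout, assuming the
# standard CASIA GNT record: a 10-byte header (4-byte little-endian sample
# size, 2-byte GB2312 tag code, 2-byte width, 2-byte height) followed by
# width*height grayscale bytes. Verify the layout against your copy of the data.
import struct

def convert_gnt_to_images(gnt_path, out_root):
    with open(gnt_path, "rb") as f:
        idx = 0
        while True:
            header = f.read(10)
            if len(header) < 10:
                break  # end of file
            _, tag, width, height = struct.unpack("<I2sHH", header)
            char = tag.decode("gb2312", errors="replace")
            bitmap = np.frombuffer(f.read(width * height), dtype=np.uint8)
            img = Image.fromarray(bitmap.reshape(height, width), mode="L")
            cls_dir = os.path.join(out_root, char)  # one folder per character
            os.makedirs(cls_dir, exist_ok=True)
            img.save(os.path.join(cls_dir, f"{idx}.png"))
            idx += 1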
# ===================== 2. Dataset loading (minimal version) =====================
class HWDB11Dataset(Dataset):
    def __init__(self, root, img_size=32, num_classes=100):
        self.root = root
        self.img_size = img_size
        self.num_classes = num_classes
        self.img_paths, self.labels = self._load_data()

    def _load_data(self):
        """Minimal loading: read only a few samples from the first num_classes classes."""
        img_paths = []
        labels = []
        # Walk the class folders (HWDB1.1 is organized one folder per character code);
        # sorted() keeps the class-to-label mapping stable across runs
        class_dirs = sorted(d for d in os.listdir(self.root) if os.path.isdir(os.path.join(self.root, d)))
        # Keep only the first num_classes classes (reduces complexity)
        selected_classes = class_dirs[:self.num_classes]
        for label, cls_dir in enumerate(selected_classes):
            cls_path = os.path.join(self.root, cls_dir)
            # Take at most 20 images per class (lowest spec, keeps the data small)
            img_files = [f for f in os.listdir(cls_path) if f.endswith((".png", ".jpg"))][:20]
            for img_file in img_files:
                img_paths.append(os.path.join(cls_path, img_file))
                labels.append(label)
        return img_paths, labels
    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        # Minimal preprocessing: grayscale + resize + normalize
        img_path = self.img_paths[idx]
        img = Image.open(img_path).convert("L")           # grayscale (saves memory)
        img = img.resize((self.img_size, self.img_size))  # resize to a fixed size
        # To tensor: (H, W) -> (1, H, W), scaled to [0, 1] (minimal preprocessing)
        img_tensor = torch.tensor(np.array(img), dtype=torch.float32).unsqueeze(0) / 255.0
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return img_tensor, label
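# Quick sanity check (illustrative addition, assumes HWDB_ROOT already points
# at extracted per-class image folders): a sample should come back as a
# (1, 32, 32) float tensor in [0, 1] plus a scalar long label.
#   ds = HWDB11Dataset(HWDB_ROOT, IMG_SIZE, NUM_CLASSES)
#   img, label = ds[0]
#   print(img.shape, img.min().item(), img.max().item(), label.item())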
# ===================== 3. Minimal CNN model (lowest spec) =====================
class SimpleHWDBModel(nn.Module):
    def __init__(self, num_classes=100):
        super(SimpleHWDBModel, self).__init__()
        # Only 2 conv layers (lowest spec, few parameters)
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1),   # 8 filters (very few)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),                 # 32 -> 16
            nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1),  # 16 filters
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),                 # 16 -> 8
        )
        # Fully connected head (very few parameters)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 8 * 8, 64),  # hidden layer of only 64 units
            nn.ReLU(),
            nn.Linear(64, num_classes)  # class logits
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x
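# Shape check for the model above (a sketch added for illustration): a
# (B, 1, 32, 32) batch maps to (B, num_classes) logits, and the whole network
# has roughly 73K parameters, so CPU training stays cheap.
#   m = SimpleHWDBModel(100)
#   print(m(torch.zeros(4, 1, 32, 32)).shape)      # torch.Size([4, 100])
#   print(sum(p.numel() for p in m.parameters()))  # ~73K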
# ===================== 4. Training + validation (minimal pipeline) =====================
def main():
    # 1. Load the dataset (lowest spec)
    dataset = HWDB11Dataset(HWDB_ROOT, IMG_SIZE, NUM_CLASSES)
    # Train/test split (minimal: 8:2)
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    # 2. Initialize the model (lowest spec)
    model = SimpleHWDBModel(NUM_CLASSES).to(DEVICE)
    criterion = nn.CrossEntropyLoss()  # classification loss
    optimizer = optim.SGD(model.parameters(), lr=LR)  # SGD is lighter than Adam
    # 3. Minimal training loop
    model.train()
    for epoch in range(EPOCHS):
        total_loss = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(DEVICE), target.to(DEVICE)
            # Forward + backward + optimize
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # Print the average loss per epoch (minimal logging)
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {avg_loss:.4f}")
    # 4. Minimal evaluation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            output = model(data)
            _, predicted = torch.max(output, 1)  # .data is unnecessary under no_grad
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print(f"\nTest set accuracy: {100 * correct / total:.2f}%")
if __name__ == "__main__":
    # Check the dataset path (replace with your actual HWDB1.1 path)
    if not os.path.exists(HWDB_ROOT):
        print(f"Error: set HWDB_ROOT to your HWDB1.1 dataset path; current path: {HWDB_ROOT}")
    else:
        main()
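# ===================== 5. (Optional) single-image inference sketch =====================
# Not part of the original post; a minimal example of how the trained model
# could classify one image, assuming weights saved as "hwdb_simple_cpu.pt"
# (hypothetical filename) and the same preprocessing as HWDB11Dataset.__getitem__:
#   model = SimpleHWDBModel(NUM_CLASSES)
#   model.load_state_dict(torch.load("hwdb_simple_cpu.pt"))
#   model.eval()
#   img = Image.open("some_char.png").convert("L").resize((IMG_SIZE, IMG_SIZE))
#   x = torch.tensor(np.array(img), dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.0
#   with torch.no_grad():
#       pred = model(x).argmax(dim=1).item()
#   print("predicted class index:", pred)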