Reborn: Learning Neural Network Algorithms from Scratch — Part 8: GPU Training in Practice with Large Datasets and Complex Models
Introduction
In the previous part we implemented a basic SRCNN super-resolution model and learned how to run training jobs in the background. This part scales the experiments up: we bring in larger datasets, build a more complex network architecture, and tune the GPU training strategy to tackle a more demanding image-reconstruction task. Along the way we will dig into the key techniques and engineering details of running large-scale deep learning experiments.
Obtaining and Processing Large Datasets
Large datasets suitable for super-resolution
To improve generalization, the following large datasets can be used:
1. **DIV2K**: 1,000 high-resolution 2K images in total (800 training, 100 validation, 100 test images)
2. **Flickr2K**: 2,650 high-resolution natural images collected from Flickr
3. **CelebA-HQ**: 30,000 high-quality face images at 1024x1024 resolution
4. **ImageNet**: a million-scale general-purpose image dataset (useful for pre-training)
Code implementation
git clone https://gitee.com/cmx1998/py-torch-learning.git
cd py-torch-learning/codes/esrgan-project
Automatic download and extraction
import os
import wget
import zipfile
import tarfile
from tqdm import tqdm

# Dataset download configuration
DATASETS = {
    "DIV2K": {
        "train": "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip",
        "valid": "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip"
    },
    "Flickr2K": {
        "url": "https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar"
    }
}

def download_dataset(url, save_dir, filename=None):
    """Download a dataset file with a progress bar."""
    # Create the target directory given by save_dir
    os.makedirs(save_dir, exist_ok=True)
    # Derive the file name from the URL if none was given
    if not filename:
        filename = url.split("/")[-1]
    save_path = os.path.join(save_dir, filename)
    # Skip the download if the file already exists
    if os.path.exists(save_path):
        print(f"File {filename} already exists, skipping download")
        return save_path
    print(f"Downloading {filename}...")
    # Use a single tqdm bar and keep the output frequency low
    with tqdm(total=100, desc=f"Downloading {filename}", unit="%") as pbar:
        def progress_bar(current, total, width=80):
            progress = current / total * 100
            pbar.n = int(progress)
            pbar.refresh()  # redraw the bar in place without emitting new lines
        wget.download(url, save_path, bar=progress_bar)
    print(f"\n{filename} downloaded")
    return save_path
def extract_archive(file_path, extract_dir):
    """Extract a downloaded dataset archive."""
    # Create the extraction directory given by extract_dir
    os.makedirs(extract_dir, exist_ok=True)
    filename = os.path.basename(file_path)
    # Marker file that records a completed extraction of this archive
    extract_flag = os.path.join(extract_dir, f".{filename}.extracted")
    # Skip extraction if the marker file already exists
    if os.path.exists(extract_flag):
        print(f"File {filename} already extracted, skipping!")
        return
    # Perform the extraction
    try:
        # Dispatch on the file extension
        if file_path.endswith(".zip"):
            # Handle zip archives
            with zipfile.ZipFile(file_path, 'r') as zip_ref:
                # Show extraction progress
                for file in tqdm(zip_ref.namelist(), desc="Extracting"):
                    zip_ref.extract(file, extract_dir)
        elif file_path.endswith(".tar") or file_path.endswith(".tar.gz"):
            # Handle tar archives
            with tarfile.open(file_path, 'r') as tar_ref:
                # Show extraction progress
                members = tar_ref.getmembers()
                for member in tqdm(members, desc="Extracting"):
                    tar_ref.extract(member, extract_dir)
        # Create the marker file once extraction succeeded
        with open(extract_flag, 'w') as f:
            f.write("Extracted successfully")
        print(f"File {os.path.basename(file_path)} extracted")
    except Exception as e:
        print(f"Extraction failed: {e}")
        # Remove the marker file on failure so the next run retries
        if os.path.exists(extract_flag):
            os.remove(extract_flag)
def prepare_large_datasets(base_dir):
    """Prepare all large datasets."""
    # Download DIV2K
    div2k_dir = os.path.join(base_dir, "DIV2K")
    for split, url in DATASETS["DIV2K"].items():
        file_path = download_dataset(url, div2k_dir)
        extract_archive(file_path, os.path.join(div2k_dir, split))
    # Download Flickr2K
    flickr_dir = os.path.join(base_dir, "Flickr2K")
    flickr_url = DATASETS["Flickr2K"]["url"]
    file_path = download_dataset(flickr_url, flickr_dir)
    extract_archive(file_path, flickr_dir)
    print("All datasets are ready")
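When used on its own, outside the training script, the preparation step is just one call. The snippet below is a minimal sketch that assumes ./data is the dataset root used throughout this article and that you are fine with several gigabytes of download traffic:

if __name__ == "__main__":
    # Download and extract DIV2K and Flickr2K into ./data
    prepare_large_datasets("./data")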
Efficient data loading
For large datasets, the data-loading pipeline needs to be optimized so the GPU stays busy:
import os
from torch.utils.data import Dataset, DataLoader, ConcatDataset

class CombinedDataset(Dataset):
    """Wrapper that combines several datasets into one."""
    def __init__(self, dataset_paths, scale_factor=4, patch_size=128, train=True,
                 augment=True, cache_in_memory=False):
        self.datasets = []
        for path in dataset_paths:
            # Only keep paths that actually exist
            if os.path.exists(path):
                dataset = SuperResolutionDataset(
                    path,
                    scale_factor=scale_factor,
                    patch_size=patch_size,
                    train=train,
                    augment=augment,
                    cache_in_memory=cache_in_memory
                )
                self.datasets.append(dataset)
            else:
                print(f"Warning: dataset path does not exist: {path}")
        if not self.datasets:
            raise ValueError("No valid dataset paths")
        self.combined = ConcatDataset(self.datasets)

    def __len__(self):
        return len(self.combined)

    def __getitem__(self, idx):
        return self.combined[idx]
# Optimized data loaders
def create_optimized_dataloaders(batch_size, num_workers=8, pin_memory=True):
    # Combine several large datasets (the paths come from the global args)
    dataset_paths = [
        os.path.join(args.dataset_path, "DIV2K"),
        os.path.join(args.dataset_path, "Flickr2K")
    ]
    train_dataset = CombinedDataset(
        dataset_paths,
        scale_factor=args.scale_factor,
        patch_size=args.patch_size,
        train=True
    )
    val_dataset = SuperResolutionDataset(
        os.path.join(args.dataset_path, "DIV2K"),
        train=False,
        scale_factor=args.scale_factor,
        patch_size=args.patch_size
    )
    # Use prefetching and multiple worker processes to speed up loading
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        prefetch_factor=2,        # prefetch upcoming batches in each worker
        persistent_workers=True   # keep worker processes alive between epochs
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )
    return train_loader, val_loader
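pin_memory only pays off when the host-to-device copy is issued asynchronously. A minimal sketch of the corresponding training-step fragment (assuming device is a CUDA device and train_loader was built as above) looks like this:

for lr_imgs, hr_imgs in train_loader:
    # With pinned host memory, non_blocking=True lets the copy overlap with GPU compute
    lr_imgs = lr_imgs.to(device, non_blocking=True)
    hr_imgs = hr_imgs.to(device, non_blocking=True)
    ...  # forward/backward pass as in the training loop further below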
A more complex model: ESRGAN
Compared with SRCNN, ESRGAN (Enhanced Super-Resolution Generative Adversarial Networks) produces high-resolution images with much richer detail. We implement its core building blocks:
import torch
import torch.nn as nn

class ResidualDenseBlock(nn.Module):
    """Residual dense block, the core component of ESRGAN."""
    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(nf + 0 * gc, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + 1 * gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # Initialize the convolution weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Residual connection with residual scaling
        return x5 * 0.2 + x
class RRDB(nn.Module):
    """Residual-in-Residual Dense Block."""
    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(nf, gc)
        self.rdb2 = ResidualDenseBlock(nf, gc)
        self.rdb3 = ResidualDenseBlock(nf, gc)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        # Residual connection with residual scaling
        return out * 0.2 + x
class RRDBNet(nn.Module):
    """Base module of the ESRGAN generator (the RRDB network)."""
    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4):
        super(RRDBNet, self).__init__()
        self.scale = scale
        # Example structure: first conv + RRDB blocks + upsampling + output conv
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
        self.body = self._make_rrdb_blocks(num_feat, num_block, num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.upsampler = self._make_upsampler(num_feat, scale)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1, bias=True)

    def _make_rrdb_blocks(self, num_feat, num_block, num_grow_ch):
        blocks = []
        for _ in range(num_block):
            blocks.append(RRDB(num_feat, num_grow_ch))
        return nn.Sequential(*blocks)

    def _make_upsampler(self, num_feat, scale):
        # Build the upsampling module: one conv + PixelShuffle(2) per factor of 2
        upsampler = []
        for _ in range(int(torch.log2(torch.tensor(scale)))):
            upsampler.append(nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True))
            upsampler.append(nn.PixelShuffle(2))
        return nn.Sequential(*upsampler)

    def forward(self, x):
        # Forward pass
        feat = self.conv_first(x)
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        out = self.conv_last(self.upsampler(feat))
        return out
# ESRGAN generator (inherits from RRDBNet to keep the interface consistent)
class ESRGAN(RRDBNet):
    """ESRGAN generator."""
    def __init__(self, scale_factor=4, num_block=23, num_grow_ch=32, **kwargs):
        super(ESRGAN, self).__init__(
            scale=scale_factor,
            num_block=num_block,
            num_grow_ch=num_grow_ch,
            **kwargs
        )
        self.scale_factor = scale_factor
        self.conv_first = nn.Conv2d(3, 64, 3, 1, 1, bias=True)
        # Keep the configuration as instance attributes for later use
        self.num_rrdb_blocks = num_block   # number of RRDB blocks
        self.num_grow_ch = num_grow_ch     # growth channels
        # Build the RRDB trunk from the instance attributes
        self.RRDB_trunk = self._make_rrdb_blocks(64, self.num_rrdb_blocks, self.num_grow_ch)
        self.trunk_conv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # Add the upsampling layers for the requested scale factor
        self.upsampler = self._make_upsampler(64, scale_factor)  # reuse the parent's upsampler builder

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk
        # Upsampling
        fea = self.upsampler(fea)  # upsample to the high-resolution size first
        fea = self.lrelu(self.HRconv(fea))
        out = self.conv_last(fea)
        return out
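A quick sanity check of the generator's output size is worth running before training; the sketch below assumes a 4x model and a 32x32 low-resolution patch:

model = ESRGAN(scale_factor=4)
lr = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    sr = model(lr)
print(sr.shape)  # expected: torch.Size([1, 3, 128, 128])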
Adversarial training
ESRGAN is trained with a GAN loss, which requires both a generator and a discriminator:
# Discriminator definition
class Discriminator(nn.Module):
    def __init__(self, num_in_ch=3, num_feat=64, skip_connection=True):
        super(Discriminator, self).__init__()
        self.skip_connection = skip_connection
        self.features = nn.Sequential(
            # Layer 1: 3 input channels (RGB image) -> 64 channels
            nn.Conv2d(num_in_ch, num_feat, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            # Layer 2: 64 -> 64 channels, stride 2 (downsampling)
            nn.Conv2d(num_feat, num_feat, 3, 2, 1),
            nn.BatchNorm2d(num_feat),
            nn.LeakyReLU(0.2, True),
            # Layer 3: 64 -> 128 channels
            nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1),      # 64 -> 128
            nn.BatchNorm2d(num_feat * 2),
            nn.LeakyReLU(0.2, True),
            # Layer 4: 128 -> 128 channels, stride 2
            nn.Conv2d(num_feat * 2, num_feat * 2, 3, 2, 1),  # 128 -> 128
            nn.BatchNorm2d(num_feat * 2),
            nn.LeakyReLU(0.2, True),
            # Layer 5: 128 -> 256 channels
            nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1),  # 128 -> 256
            nn.BatchNorm2d(num_feat * 4),
            nn.LeakyReLU(0.2, True),
            # Layer 6: 256 -> 256 channels, stride 2
            nn.Conv2d(num_feat * 4, num_feat * 4, 3, 2, 1),  # 256 -> 256
            nn.BatchNorm2d(num_feat * 4),
            nn.LeakyReLU(0.2, True),
            # Layer 7: 256 -> 512 channels
            nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1),  # 256 -> 512
            nn.BatchNorm2d(num_feat * 8),
            nn.LeakyReLU(0.2, True),
            # Layer 8: 512 -> 512 channels, stride 2
            nn.Conv2d(num_feat * 8, num_feat * 8, 3, 2, 1),  # 512 -> 512
            nn.BatchNorm2d(num_feat * 8),
            nn.LeakyReLU(0.2, True),
        )
        # Classification head (real vs. fake)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(num_feat * 8, num_feat * 16, 1, 1, 0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(num_feat * 16, 1, 1, 1, 0),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x
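Because of the global average pooling in the classification head, this discriminator accepts arbitrary patch sizes and returns one logit per image. A minimal check, as a sketch with a random 128x128 batch:

disc = Discriminator()
logits = disc(torch.randn(2, 3, 128, 128))
print(logits.shape)  # expected: torch.Size([2, 1, 1, 1])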
# Hybrid loss: perceptual content loss
import torchvision
from torchvision import transforms

class ContentLoss(nn.Module):
    def __init__(self, device):
        super(ContentLoss, self).__init__()
        # Use a pretrained VGG-19 as a fixed feature extractor
        vgg = torchvision.models.vgg19(pretrained=True).features[:35].eval()
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg.to(device)
        self.criterion = nn.L1Loss()
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])

    def forward(self, sr, hr):
        # Normalize the inputs to match the VGG training statistics
        sr_norm = self.normalize(sr)
        hr_norm = self.normalize(hr)
        # Extract features
        sr_feat = self.vgg(sr_norm)
        hr_feat = self.vgg(hr_norm)
        return self.criterion(sr_feat, hr_feat)
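The training loop below also calls a GANLoss helper that lives elsewhere in the repository. For reference, a minimal vanilla version could look like the following sketch; the class name and the gan_type argument mirror how it is used later, everything else is an assumption:

class GANLoss(nn.Module):
    """Minimal vanilla GAN loss: BCE-with-logits against all-real or all-fake labels."""
    def __init__(self, gan_type='vanilla', real_label=1.0, fake_label=0.0):
        super(GANLoss, self).__init__()
        assert gan_type == 'vanilla', "only the vanilla variant is sketched here"
        self.real_label = real_label
        self.fake_label = fake_label
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, pred, target_is_real):
        # Build a label tensor with the same shape as the prediction
        value = self.real_label if target_is_real else self.fake_label
        target = torch.full_like(pred, value)
        return self.loss(pred, target)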
GPU training optimizations
Mixed-precision training
from torch.amp import GradScaler, autocast
def train():
    # Parse command-line arguments
    args = parse_args()
    # Set up logging
    train_logger = setup_logger(
        logger_name="train",
        log_file="train.log",
        log_dir="logs/train",
        level=logging.DEBUG  # debug level for more detailed output
    )
    # Load the datasets
    if args.download_datasets:
        train_logger.info("Starting automatic dataset download...")
        prepare_large_datasets(args.dataset_path)  # download into the configured path
        train_logger.info("Dataset download finished")
    train_loader, val_loader = create_optimized_dataloaders(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory
    )
    train_logger.info(f"Datasets loaded - train: {len(train_loader.dataset)} samples, val: {len(val_loader.dataset)} samples")
    # Initialize TensorBoard
    tb_writer = init_tensorboard(os.path.join(args.log_dir, "tensorboard"))
    # Select the device
    device = torch.device(args.device)
    train_logger.info(f"Using device: {device}")
    # Initialize the models
    generator = ESRGAN(scale_factor=args.scale_factor).to(device)
    discriminator = Discriminator().to(device)
    # Initialize the loss functions
    content_criterion = ContentLoss(device)
    gan_criterion = GANLoss(gan_type='vanilla').to(device)
    # Initialize the optimizers
    g_optimizer = optim.Adam(generator.parameters(), lr=args.lr, betas=(0.9, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.9, 0.999))
    # Learning-rate schedulers
    g_scheduler = CosineAnnealingWarmRestarts(g_optimizer, T_0=100, T_mult=2)
    d_scheduler = CosineAnnealingWarmRestarts(d_optimizer, T_0=100, T_mult=2)
    # Mixed-precision training
    scaler = GradScaler('cuda', enabled=args.use_amp)  # 'cuda' is the default, stated explicitly here
    # Resume from a checkpoint if one was given
    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume, map_location=device)
            generator.load_state_dict(checkpoint['generator_state_dict'])
            discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
            g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
            d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])
            start_epoch = checkpoint['epoch'] + 1
            train_logger.info(f"Resumed from checkpoint: {args.resume}, starting at epoch {start_epoch}")
        else:
            train_logger.warning(f"Checkpoint file not found: {args.resume}, training from scratch")
    # Training loop
    train_logger.info("Starting training...")
    for epoch in range(start_epoch, args.epochs):
        generator.train()
        discriminator.train()
        total_g_loss = 0.0
        total_d_loss = 0.0
        # Progress bar
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs}")
        for batch_idx, (lr_imgs, hr_imgs) in enumerate(pbar):
            lr_imgs = lr_imgs.to(device)
            hr_imgs = hr_imgs.to(device)
            # ---------------------
            # Train the generator
            # ---------------------
            grad_accum_steps = 4
            # Start a new accumulation window by clearing the generator gradients
            if batch_idx % grad_accum_steps == 0:
                g_optimizer.zero_grad()
            with autocast('cuda', enabled=args.use_amp):  # compute the losses inside the autocast context
                # Generate the super-resolved images (already upscaled to the HR size)
                sr_imgs = generator(lr_imgs)
                # Generator losses (SR and HR now have matching sizes)
                assert sr_imgs.shape == hr_imgs.shape, "SR and HR sizes do not match!"
                content_loss = content_criterion(sr_imgs, hr_imgs)
                fake_pred = discriminator(sr_imgs)
                gan_loss = gan_criterion(fake_pred, True)
                # Total generator loss (the content term carries the larger weight)
                g_loss = content_loss * 0.01 + gan_loss * 0.005
            # Gradient accumulation: average the loss over the window, step only at the window boundary
            scaled_loss = g_loss / grad_accum_steps
            scaler.scale(scaled_loss).backward()
            if (batch_idx + 1) % grad_accum_steps == 0:
                scaler.step(g_optimizer)
            # ---------------------
            # Train the discriminator
            # ---------------------
            d_optimizer.zero_grad()
            # Note: the device type must be passed to autocast explicitly
            with autocast('cuda', enabled=args.use_amp):
                # Loss on real images
                real_pred = discriminator(hr_imgs)
                real_loss = gan_criterion(real_pred, True)
                # Loss on generated images
                fake_pred = discriminator(sr_imgs.detach())  # detach so the generator is not updated
                fake_loss = gan_criterion(fake_pred, False)
                # Total discriminator loss
                d_loss = (real_loss + fake_loss) * 0.5
            # Backward pass and optimizer step; the shared scaler is updated once per iteration here
            scaler.scale(d_loss).backward()
            scaler.step(d_optimizer)
            scaler.update()
            # Accumulate the running losses
            total_g_loss += g_loss.item()
            total_d_loss += d_loss.item()
            # Logging
            if batch_idx % args.log_interval == 0:
                avg_g_loss = total_g_loss / (batch_idx + 1)
                avg_d_loss = total_d_loss / (batch_idx + 1)
                pbar.set_postfix({"G Loss": f"{avg_g_loss:.4f}", "D Loss": f"{avg_d_loss:.4f}"})
                # Write to TensorBoard
                global_step = epoch * len(train_loader) + batch_idx
                tb_writer.add_scalar('Loss/Generator', g_loss.item(), global_step)
                tb_writer.add_scalar('Loss/Discriminator', d_loss.item(), global_step)
        # Update the learning rates at the end of each epoch
        g_scheduler.step()
        d_scheduler.step()
        # Average losses for the epoch
        avg_g_loss_epoch = total_g_loss / len(train_loader)
        avg_d_loss_epoch = total_d_loss / len(train_loader)
        train_logger.info(f"Epoch {epoch+1} - G Loss: {avg_g_loss_epoch:.4f}, D Loss: {avg_d_loss_epoch:.4f}")
        # Save a checkpoint
        if (epoch + 1) % args.save_freq == 0:
            save_checkpoint(
                epoch + 1,
                generator,
                discriminator,
                g_optimizer,
                d_optimizer,
                args.checkpoint_dir,
                train_logger
            )
        # Validation
        if (epoch + 1) % args.val_interval == 0:
            generator.eval()
            val_loss = 0.0
            with torch.no_grad():
                for lr_imgs, hr_imgs in val_loader:
                    lr_imgs = lr_imgs.to(device)
                    hr_imgs = hr_imgs.to(device)
                    sr_imgs = generator(lr_imgs)
                    loss = content_criterion(sr_imgs, hr_imgs)
                    val_loss += loss.item()
            avg_val_loss = val_loss / len(val_loader)
            train_logger.info(f"Validation loss: {avg_val_loss:.4f}")
            tb_writer.add_scalar('Loss/Validation', avg_val_loss, epoch)
            # Log a few example images
            tb_writer.add_images('LR Images', lr_imgs[:4], epoch)
            tb_writer.add_images('HR Images', hr_imgs[:4], epoch)
            tb_writer.add_images('SR Images', sr_imgs[:4], epoch, dataformats='NCHW')
    # Training finished
    train_logger.info("Training finished!")
    tb_writer.close()
Gradient accumulation and learning-rate scheduling
def main():
    args = parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Set the random seeds for reproducibility
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    # Run the training loop (train() parses its own arguments and selects its own device)
    train()
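Stripped of the GAN-specific details, the accumulation-plus-scheduling pattern used inside train() boils down to the following sketch; model, criterion, optimizer, scheduler, loader and num_epochs are placeholders, and the window size of 4 matches grad_accum_steps above:

grad_accum_steps = 4
for epoch in range(num_epochs):
    optimizer.zero_grad()
    for batch_idx, (inputs, targets) in enumerate(loader):
        loss = criterion(model(inputs), targets) / grad_accum_steps  # average over the window
        loss.backward()                                              # gradients accumulate in .grad
        if (batch_idx + 1) % grad_accum_steps == 0:
            optimizer.step()       # effective batch size = batch_size * grad_accum_steps
            optimizer.zero_grad()
    scheduler.step()               # e.g. CosineAnnealingWarmRestarts, stepped once per epoch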
Extended run script
An enhanced launch script for large experiments:
#!/bin/bash
# run_esrgan_large.sh
echo "Launching the ESRGAN large-scale training job..."
# Set the working directory (adjust as needed)
cd /home/vscode/workspace/py-torch-learning/codes/esrgan-project
# Create the directories
mkdir -p data checkpoints logs
# Install dependencies
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
# Record the start time
start_time=$(date +%s)
echo "Experiment started at: $(date)"
# Check GPU status
nvidia-smi
# Name the log file once so the same name is used everywhere below
log_file="training_log_esrgan_$(date +%Y%m%d_%H%M%S).txt"
# Start training
nohup python3 -u main.py \
    --epochs 100 \
    --batch_size 8 \
    --dataset_path ./data \
    --checkpoint_dir ./checkpoints \
    --download_datasets \
    > "$log_file" 2>&1 &
# Record the process ID and log file
echo "Training job started in the background, PID: $!"
echo "Log file: $log_file"
# Monitor GPU usage (one snapshot every 5 minutes)
while true; do
    echo "GPU check: $(date)" >> "$log_file"
    nvidia-smi >> "$log_file" 2>&1
    sleep 300  # 5 minutes
done &
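Once launched, the job can be followed or stopped with the usual shell tools; the PID placeholder below stands for the training PID printed above (the background monitoring loop has its own PID and is stopped the same way):

tail -f "$log_file"   # follow the training output live
kill <PID>            # stop the background job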
Monitoring and analyzing experiments
Visualize the training process with TensorBoard:
from torch.utils.tensorboard import SummaryWriter

def init_tensorboard(log_dir):
    """Initialize TensorBoard."""
    writer = SummaryWriter(log_dir=log_dir)
    return writer

def log_to_tensorboard(writer, epoch, train_metrics, val_metrics, images):
    """Write training metrics and images to TensorBoard."""
    # Scalar metrics
    writer.add_scalar('Loss/Generator', train_metrics['gen_loss'], epoch)
    writer.add_scalar('Loss/Discriminator', train_metrics['dis_loss'], epoch)
    writer.add_scalar('PSNR/Train', train_metrics['psnr'], epoch)
    writer.add_scalar('PSNR/Validation', val_metrics['psnr'], epoch)
    writer.add_scalar('LearningRate/Generator',
                      train_metrics['gen_lr'], epoch)
    # Example images (every 10 epochs)
    if epoch % 10 == 0:
        lr_img, sr_img, hr_img = images
        writer.add_image('Input/LowResolution', lr_img, epoch)
        writer.add_image('Output/SuperResolution', sr_img, epoch)
        writer.add_image('Target/HighResolution', hr_img, epoch)
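The metrics dictionaries above expect a PSNR value that the training code has to compute somewhere; a minimal helper, assuming image tensors scaled to [0, 1], could look like this:

import torch

def calculate_psnr(sr, hr, max_val=1.0):
    """Peak signal-to-noise ratio in dB between two image tensors with values in [0, max_val]."""
    mse = torch.mean((sr - hr) ** 2)
    if mse == 0:
        return float('inf')
    return (10 * torch.log10(max_val ** 2 / mse)).item()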
Summary of outputs
- Main output: model checkpoints
  Contents: generator and discriminator weights saved during training, optimizer states, and the epoch number.
  Path: set by the --checkpoint_dir argument in config.py; the default is ./checkpoints/, with files named checkpoint_epoch_{epoch}.pth (for example checkpoint_epoch_10.pth).
  When: saved every --save_freq epochs (default: 10), adjustable on the command line.
- Logs
  Contents: training losses, validation metrics, and key events (download/extraction progress, model loading, and so on).
  Paths:
  Text log: set by --log_dir in config.py; the default is ./logs/train/train.log.
  Terminal output: when launched via run_esrgan_large.sh, it is redirected to training_log_esrgan_$(date).txt in the same directory as the script.
- TensorBoard visualizations
  Contents: training/validation loss curves and generated super-resolution images (LR input, HR ground truth, and SR prediction side by side).
  Path: ./logs/tensorboard/ by default (set by the init_tensorboard function in main.py, based on --log_dir).
  How to view: run tensorboard --logdir=./logs/tensorboard and open the local port in a browser.
  Note that the tensorboard and protobuf versions must be compatible.
- Super-resolution result images (optional)
  Contents: super-resolution images (SR images) produced during validation or inference.
  Path: set by --result_dir in config.py; the default is ./results/ (the directory is already created in the code; saving the images can be added to the inference logic, as sketched below).
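Saving the SR outputs can be done with torchvision's save_image; the snippet below is a minimal sketch of that missing piece (the result_dir default and the file-naming scheme are assumptions):

import os
from torchvision.utils import save_image

def save_sr_images(sr_imgs, result_dir="./results", prefix="sr"):
    """Write a batch of SR tensors (values in [0, 1]) to result_dir as PNG files."""
    os.makedirs(result_dir, exist_ok=True)
    for i, img in enumerate(sr_imgs):
        save_image(img, os.path.join(result_dir, f"{prefix}_{i:04d}.png"))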
Summary and next steps
In this part we have covered:
1. **Large-dataset management**: automatic download, extraction, and combination of several large datasets, with an optimized data-loading pipeline
2. **A more complex model**: an ESRGAN generator built from residual dense blocks, which produces much richer detail than SRCNN
3. **Advanced training strategies**: mixed-precision training, gradient accumulation, and cosine-annealing learning-rate schedules to improve GPU utilization
4. **A complete monitoring setup**: log files, GPU monitoring, and TensorBoard visualization for end-to-end tracking of the experiment
Directions worth exploring next:
- Larger models such as RCAN or SwinIR
- Perceptual losses and improved GAN variants (e.g. Relativistic GAN)
- Model and data parallelism for multi-GPU training
- Model compression and acceleration for real-time super-resolution
- Video super-resolution, taking temporal consistency into account
In the next part we will explore state-of-the-art Vision Transformer models for super-resolution to push reconstruction quality further.