Reborn: Learning Neural Network Algorithms from Scratch — Part 8: GPU Training Practice with Large Datasets and Complex Models

Introduction

In the previous post, we implemented a basic SRCNN super-resolution model and learned how to run training jobs in the background. This post scales the experiments up: we introduce larger datasets, build a more complex network architecture, and optimize the GPU training strategy to tackle a more challenging image reconstruction task. Along the way we will dig into the key techniques and engineering details behind large-scale deep learning experiments.

Obtaining and Processing Large Datasets

Large Datasets Suitable for Super-Resolution

To improve the model's generalization ability, we can use the following large datasets:

1. DIV2K: high-resolution 2K images, split into 800 training and 100 validation images (the HR train/valid archives are downloaded below)
2. Flickr2K: 2,650 high-resolution natural images collected from Flickr
3. CelebA-HQ: 30,000 high-quality face images (1024x1024 resolution)
4. ImageNet: a general-purpose dataset with millions of images (useful for pre-training)

Code Setup

git clone https://gitee.com/cmx1998/py-torch-learning.git
cd py-torch-learning/codes/esrgan-project
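
The launch script later in this post installs dependencies from a requirements.txt. The exact contents are not shown in the original project, so the list below is an assumption based on the imports used throughout this post; pin versions to match your environment.

# requirements.txt (assumed minimal set; pin versions as needed)
torch
torchvision
tqdm
wget
tensorboard
protobuf   # keep this compatible with your tensorboard version (see the note in the output summary)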

Automatic Download and Extraction

import os
import wget
import zipfile
import tarfile
from tqdm import tqdm

# Dataset download configuration
DATASETS = {
    "DIV2K": {
        "train": "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip",
        "valid": "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip"
    },
    "Flickr2K": {
        "url": "https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar"
    }
}

def download_dataset(url, save_dir, filename=None):
    """Download a dataset archive with a progress bar."""
    # Create the target directory if it does not exist
    os.makedirs(save_dir, exist_ok=True)
    
    # Fall back to the file name embedded in the URL
    if not filename:
        filename = url.split("/")[-1]
    save_path = os.path.join(save_dir, filename)
    
    # Skip the download if the file is already present
    if os.path.exists(save_path):
        print(f"File {filename} already exists, skipping download")
        return save_path

    print(f"Downloading {filename}...")
    # Use a single tqdm bar and redraw it in place to keep the output quiet
    with tqdm(total=100, desc=f"Downloading {filename}", unit="%") as pbar:
        def progress_bar(current, total, width=80):
            pbar.n = int(current / total * 100)
            pbar.refresh()      # redraw the bar without printing new lines
            
        wget.download(url, save_path, bar=progress_bar)
    print(f"\n{filename} downloaded")
    return save_path

def extract_archive(file_path, extract_dir):
    """Extract a dataset archive (zip or tar)."""
    # Create the extraction directory if it does not exist
    os.makedirs(extract_dir, exist_ok=True)
    filename = os.path.basename(file_path)
    # Marker file recording that this archive has already been extracted
    extract_flag = os.path.join(extract_dir, f".{filename}.extracted")
    
    # Skip extraction if the marker file exists
    if os.path.exists(extract_flag):
        print(f"File {filename} already extracted, skipping!")
        return
    
    # Perform the extraction
    try:
        # Dispatch on the archive suffix
        if file_path.endswith(".zip"):
            # Handle zip archives
            with zipfile.ZipFile(file_path, 'r') as zip_ref:
                # Show extraction progress
                for file in tqdm(zip_ref.namelist(), desc="Extracting"):
                    zip_ref.extract(file, extract_dir)
        elif file_path.endswith(".tar") or file_path.endswith(".tar.gz"):
            # Handle tar archives
            with tarfile.open(file_path, 'r') as tar_ref:
                # Show extraction progress
                members = tar_ref.getmembers()
                for member in tqdm(members, desc="Extracting"):
                    tar_ref.extract(member, extract_dir)
                    
        # Create the marker file once extraction succeeds
        with open(extract_flag, 'w') as f:
            f.write("Extracted successfully")
        print(f"File {os.path.basename(file_path)} extracted")
    except Exception as e:
        print(f"Extraction failed: {e}")
        # Remove the marker file on failure so the next run retries
        if os.path.exists(extract_flag):
            os.remove(extract_flag)

def prepare_large_datasets(base_dir):
    """Download and extract all large datasets."""
    # DIV2K (train and validation splits)
    div2k_dir = os.path.join(base_dir, "DIV2K")
    for split, url in DATASETS["DIV2K"].items():
        file_path = download_dataset(url, div2k_dir)
        extract_archive(file_path, os.path.join(div2k_dir, split))
    
    # Flickr2K
    flickr_dir = os.path.join(base_dir, "Flickr2K")
    flickr_url = DATASETS["Flickr2K"]["url"]
    file_path = download_dataset(flickr_url, flickr_dir)
    extract_archive(file_path, flickr_dir)
    
    print("All datasets are ready")

Efficient Data Loading Strategy

For large datasets, the data loading pipeline needs to be optimized so the GPU stays fully utilized:

from torch.utils.data import ConcatDataset, DataLoader, Dataset

class CombinedDataset(Dataset):
    """Wrapper that concatenates several datasets (SuperResolutionDataset comes from the previous post's project code)."""
    def __init__(self, dataset_paths, scale_factor=4, patch_size=128, train=True, 
                 augment=True, cache_in_memory=False):
        self.datasets = []
        for path in dataset_paths:
            # Skip paths that do not exist
            if os.path.exists(path):
                dataset = SuperResolutionDataset(
                    path, 
                    scale_factor=scale_factor,
                    patch_size=patch_size,
                    train=train,
                    augment=augment,
                    cache_in_memory=cache_in_memory
                )
                self.datasets.append(dataset)
            else:
                print(f"Warning: dataset path does not exist: {path}")
        
        if not self.datasets:
            raise ValueError("No valid dataset paths")
        self.combined = ConcatDataset(self.datasets)
        
    def __len__(self):
        return len(self.combined)
    
    def __getitem__(self, idx):
        return self.combined[idx]

# Optimized data loaders
def create_optimized_dataloaders(batch_size, num_workers=8, pin_memory=True):
    # Combine several large datasets (args is the parsed command-line config)
    dataset_paths = [
        os.path.join(args.dataset_path, "DIV2K"),
        os.path.join(args.dataset_path, "Flickr2K")
    ]
    
    train_dataset = CombinedDataset(
        dataset_paths,
        scale_factor=args.scale_factor,
        patch_size=args.patch_size,
        train=True
    )
    
    val_dataset = SuperResolutionDataset(
        os.path.join(args.dataset_path, "DIV2K"),
        train=False,
        scale_factor=args.scale_factor,
        patch_size=args.patch_size
    )
    
    # Use prefetching and multiple worker processes to speed up loading
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        prefetch_factor=2,  # each worker prefetches the next batches
        persistent_workers=True  # keep worker processes alive between epochs
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )
    
    return train_loader, val_loader

A More Complex Model: ESRGAN

Compared with SRCNN, ESRGAN (Enhanced Super-Resolution Generative Adversarial Networks) generates high-resolution images with much richer detail. Here we implement its core building blocks:

import math

import torch
import torch.nn as nn

class ResidualDenseBlock(nn.Module):
    """Residual Dense Block, the core component of ESRGAN."""
    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(nf + 0 * gc, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + 1 * gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        
        # Initialize the weights (the reference ESRGAN implementation additionally scales them down, e.g. by 0.1)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Scaled residual connection
        return x5 * 0.2 + x

class RRDB(nn.Module):
    """Residual-in-Residual Dense Block."""
    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(nf, gc)
        self.rdb2 = ResidualDenseBlock(nf, gc)
        self.rdb3 = ResidualDenseBlock(nf, gc)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        # Scaled residual connection
        return out * 0.2 + x

class RRDBNet(nn.Module):
    """Backbone of the ESRGAN generator (the RRDB network)."""
    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4):
        super(RRDBNet, self).__init__()
        self.scale = scale
        # Example structure: first conv + RRDB trunk + upsampler + output conv
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
        self.body = self._make_rrdb_blocks(num_feat, num_block, num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.upsampler = self._make_upsampler(num_feat, scale)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1, bias=True)

    def _make_rrdb_blocks(self, num_feat, num_block, num_grow_ch):
        blocks = []
        for _ in range(num_block):
            blocks.append(RRDB(num_feat, num_grow_ch))
        return nn.Sequential(*blocks)

    def _make_upsampler(self, num_feat, scale):
        # Build the upsampling module: one (conv + PixelShuffle) stage per factor of 2
        upsampler = []
        for _ in range(int(math.log2(scale))):
            upsampler.append(nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True))
            upsampler.append(nn.PixelShuffle(2))
        return nn.Sequential(*upsampler)

    def forward(self, x):
        # Shallow features plus residual trunk, then upsample and project back to RGB
        feat = self.conv_first(x)
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        out = self.conv_last(self.upsampler(feat))
        return out

# The ESRGAN generator (inherits from RRDBNet to keep the interface consistent)
class ESRGAN(RRDBNet):
    """ESRGAN generator."""
    def __init__(self, scale_factor=4, num_block=23, num_grow_ch=32, **kwargs):
        super(ESRGAN, self).__init__(
            scale=scale_factor,
            num_block=num_block,
            num_grow_ch=num_grow_ch,
            **kwargs
        )
        self.scale_factor = scale_factor
        self.conv_first = nn.Conv2d(3, 64, 3, 1, 1, bias=True)
        # Keep the configuration as instance attributes for later use
        self.num_rrdb_blocks = num_block  # number of RRDB blocks
        self.num_grow_ch = num_grow_ch    # growth channels
        # Build the RRDB trunk from the instance attributes
        self.RRDB_trunk = self._make_rrdb_blocks(64, self.num_rrdb_blocks, self.num_grow_ch)
        self.trunk_conv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        
        # Upsampling layers determined by the scale factor (reuse the parent class helper)
        self.upsampler = self._make_upsampler(64, scale_factor)

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk
        # Upsample to the target resolution first
        fea = self.upsampler(fea)
        
        fea = self.lrelu(self.HRconv(fea))
        out = self.conv_last(fea)
        return out
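
As a quick sanity check (a minimal sketch with an arbitrary input size), the generator should enlarge the spatial dimensions by the scale factor:

# Minimal shape check for the generator (hypothetical 32x32 low-resolution patch)
model = ESRGAN(scale_factor=4)
lr = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    sr = model(lr)
print(sr.shape)   # expected: torch.Size([1, 3, 128, 128])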

Adversarial Training Strategy

ESRGAN is trained with a GAN loss, which requires both a generator and a discriminator:

# Discriminator definition
class Discriminator(nn.Module):
    def __init__(self, num_in_ch=3, num_feat=64, skip_connection=True):
        super(Discriminator, self).__init__()
        self.skip_connection = skip_connection
        
        self.features = nn.Sequential(
            # Layer 1: 3 input channels (RGB image) -> 64 channels
            nn.Conv2d(num_in_ch, num_feat, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            
            # Layer 2: 64 -> 64 channels, stride 2 (downsampling)
            nn.Conv2d(num_feat, num_feat, 3, 2, 1),
            nn.BatchNorm2d(num_feat),
            nn.LeakyReLU(0.2, True),
            
            # Layer 3: 64 -> 128 channels
            nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1),     # 64 -> 128
            nn.BatchNorm2d(num_feat * 2),
            nn.LeakyReLU(0.2, True),
            
            # Layer 4: 128 -> 128 channels, stride 2
            nn.Conv2d(num_feat * 2, num_feat * 2, 3, 2, 1), # 128 -> 128
            nn.BatchNorm2d(num_feat * 2),
            nn.LeakyReLU(0.2, True),
            
            # Layer 5: 128 -> 256 channels
            nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1), # 128 -> 256
            nn.BatchNorm2d(num_feat * 4),
            nn.LeakyReLU(0.2, True),
            
            # Layer 6: 256 -> 256 channels, stride 2
            nn.Conv2d(num_feat * 4, num_feat * 4, 3, 2, 1), # 256 -> 256
            nn.BatchNorm2d(num_feat * 4),
            nn.LeakyReLU(0.2, True),
            
            # Layer 7: 256 -> 512 channels
            nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1), # 256 -> 512
            nn.BatchNorm2d(num_feat * 8),
            nn.LeakyReLU(0.2, True),
            
            # Layer 8: 512 -> 512 channels, stride 2
            nn.Conv2d(num_feat * 8, num_feat * 8, 3, 2, 1), # 512 -> 512
            nn.BatchNorm2d(num_feat * 8),
            nn.LeakyReLU(0.2, True),
        )
        
        # Classification head (real vs. fake)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(num_feat * 8, num_feat * 16, 1, 1, 0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(num_feat * 16, 1, 1, 1, 0),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

# Combined loss: perceptual content loss based on VGG features
import torchvision
from torchvision import transforms

class ContentLoss(nn.Module):
    def __init__(self, device):
        super(ContentLoss, self).__init__()
        # Use a pretrained VGG19 as a fixed feature extractor
        vgg = torchvision.models.vgg19(weights=torchvision.models.VGG19_Weights.IMAGENET1K_V1).features[:35].eval()
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg.to(device)
        self.criterion = nn.L1Loss()
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                              std=[0.229, 0.224, 0.225])

    def forward(self, sr, hr):
        # Normalize the inputs to match the VGG training distribution
        sr_norm = self.normalize(sr)
        hr_norm = self.normalize(hr)
        # Extract and compare features
        sr_feat = self.vgg(sr_norm)
        hr_feat = self.vgg(hr_norm)
        return self.criterion(sr_feat, hr_feat)
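
The training loop below also relies on a GANLoss helper (used as GANLoss(gan_type='vanilla')) that is not reproduced in this post. A minimal sketch of a vanilla GAN loss, assuming a BCE-with-logits formulation, could look like this:

import torch
import torch.nn as nn

class GANLoss(nn.Module):
    """Minimal 'vanilla' GAN loss: BCE-with-logits against real/fake labels (simplified sketch)."""
    def __init__(self, gan_type='vanilla', real_label=1.0, fake_label=0.0):
        super(GANLoss, self).__init__()
        if gan_type != 'vanilla':
            raise NotImplementedError(f"gan_type {gan_type} is not supported in this sketch")
        self.real_label = real_label
        self.fake_label = fake_label
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, pred, target_is_real):
        # Build a label tensor with the same shape as the discriminator output
        label = self.real_label if target_is_real else self.fake_label
        target = torch.full_like(pred, label)
        return self.loss(pred, target)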

GPU Training Optimization Techniques

Mixed-Precision Training

import logging
import os

import torch
import torch.optim as optim
from torch.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from tqdm import tqdm

def train():
    # Parse command-line arguments
    args = parse_args()
    
    # Set up logging
    train_logger = setup_logger(
        logger_name="train",
        log_file="train.log",
        log_dir="logs/train",
        level=logging.DEBUG  # debug level for more detailed output
    )
    
    # Prepare the datasets
    if args.download_datasets:
        train_logger.info("Starting automatic dataset download...")
        prepare_large_datasets(args.dataset_path)  # download into the configured path
        train_logger.info("Dataset download finished")

    train_loader, val_loader = create_optimized_dataloaders(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory
    )
    train_logger.info(f"Datasets loaded - train: {len(train_loader.dataset)} samples, val: {len(val_loader.dataset)} samples")
    
    # Initialize TensorBoard
    tb_writer = init_tensorboard(os.path.join(args.log_dir, "tensorboard"))
    
    # Select the device
    device = torch.device(args.device)
    train_logger.info(f"Using device: {device}")
    
    # Initialize the models
    generator = ESRGAN(scale_factor=args.scale_factor).to(device)
    discriminator = Discriminator().to(device)
    
    # Initialize the loss functions
    content_criterion = ContentLoss(device)
    gan_criterion = GANLoss(gan_type='vanilla').to(device)
    
    # Initialize the optimizers
    g_optimizer = optim.Adam(generator.parameters(), lr=args.lr, betas=(0.9, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.9, 0.999))
    
    # Learning-rate schedulers
    g_scheduler = CosineAnnealingWarmRestarts(g_optimizer, T_0=100, T_mult=2)
    d_scheduler = CosineAnnealingWarmRestarts(d_optimizer, T_0=100, T_mult=2)
    
    # Mixed-precision training
    scaler = GradScaler('cuda', enabled=args.use_amp)       # explicitly target CUDA (also the default)
    
    # Resume from a checkpoint if one was given
    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume, map_location=device)
            generator.load_state_dict(checkpoint['generator_state_dict'])
            discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
            g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
            d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])
            start_epoch = checkpoint['epoch'] + 1
            train_logger.info(f"Resumed from checkpoint: {args.resume}, starting at epoch {start_epoch}")
        else:
            train_logger.warning(f"Checkpoint not found: {args.resume}, training from scratch")
    
    # Training loop
    train_logger.info("Starting training...")
    for epoch in range(start_epoch, args.epochs):
        generator.train()
        discriminator.train()
        
        total_g_loss = 0.0
        total_d_loss = 0.0
        
        # Progress bar
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs}")
        
        for batch_idx, (lr_imgs, hr_imgs) in enumerate(pbar):
            lr_imgs = lr_imgs.to(device)
            hr_imgs = hr_imgs.to(device)
            
            # ---------------------
            #  Train the generator
            # ---------------------
            grad_accum_steps = 4
            if batch_idx % grad_accum_steps == 0:
                g_optimizer.zero_grad()     # reset gradients at the start of each accumulation window
            
            with autocast('cuda', enabled=args.use_amp):    # compute the losses inside the autocast context
                # Generate super-resolution images (sr_imgs is already upscaled to the HR size)
                sr_imgs = generator(lr_imgs)
                
                # Generator losses
                if batch_idx == 0:
                    train_logger.debug(f"SR size: {sr_imgs.shape}, HR size: {hr_imgs.shape}")
                assert sr_imgs.shape == hr_imgs.shape, "SR and HR sizes do not match!"
                content_loss = content_criterion(sr_imgs, hr_imgs)
                fake_pred = discriminator(sr_imgs)
                gan_loss = gan_criterion(fake_pred, True)
                
                # Total generator loss (the content term carries the larger weight)
                g_loss = content_loss * 0.01 + gan_loss * 0.005
                
            # Gradient accumulation: scale the loss and only step every grad_accum_steps batches
            scaled_loss = g_loss / grad_accum_steps  # average over the accumulation window
            scaler.scale(scaled_loss).backward()

            if (batch_idx + 1) % grad_accum_steps == 0:
                # Note: the scaler is shared with the discriminator below; a separate scaler per
                # optimizer would keep the loss scale constant within an accumulation window.
                scaler.step(g_optimizer)
                scaler.update()
            
            # ---------------------
            #  Train the discriminator
            # ---------------------
            d_optimizer.zero_grad()
            
            # Note: the device type must be passed to autocast explicitly
            with autocast('cuda', enabled=args.use_amp):
                # Loss on real images
                real_pred = discriminator(hr_imgs)
                real_loss = gan_criterion(real_pred, True)
                
                # Loss on generated images
                fake_pred = discriminator(sr_imgs.detach())  # detach so the generator is not updated
                fake_loss = gan_criterion(fake_pred, False)
                
                # Total discriminator loss
                d_loss = (real_loss + fake_loss) * 0.5
            
            # Backward pass and optimizer step
            scaler.scale(d_loss).backward()
            scaler.step(d_optimizer)
            scaler.update()
            
            # Accumulate losses
            total_g_loss += g_loss.item()
            total_d_loss += d_loss.item()
            
            # Logging
            if batch_idx % args.log_interval == 0:
                avg_g_loss = total_g_loss / (batch_idx + 1)
                avg_d_loss = total_d_loss / (batch_idx + 1)
                pbar.set_postfix({"G Loss": f"{avg_g_loss:.4f}", "D Loss": f"{avg_d_loss:.4f}"})
                
                # Write to TensorBoard
                global_step = epoch * len(train_loader) + batch_idx
                tb_writer.add_scalar('Loss/Generator', g_loss.item(), global_step)
                tb_writer.add_scalar('Loss/Discriminator', d_loss.item(), global_step)
        
        # Update the learning rates at the end of each epoch
        g_scheduler.step()
        d_scheduler.step()
        
        # Average losses over the epoch
        avg_g_loss_epoch = total_g_loss / len(train_loader)
        avg_d_loss_epoch = total_d_loss / len(train_loader)
        train_logger.info(f"Epoch {epoch+1} - G Loss: {avg_g_loss_epoch:.4f}, D Loss: {avg_d_loss_epoch:.4f}")
        
        # Save a checkpoint
        if (epoch + 1) % args.save_freq == 0:
            save_checkpoint(
                epoch + 1,
                generator,
                discriminator,
                g_optimizer,
                d_optimizer,
                args.checkpoint_dir,
                train_logger
            )
        
        # Validation
        if (epoch + 1) % args.val_interval == 0:
            generator.eval()
            val_loss = 0.0
            
            with torch.no_grad():
                for lr_imgs, hr_imgs in val_loader:
                    lr_imgs = lr_imgs.to(device)
                    hr_imgs = hr_imgs.to(device)
                    
                    sr_imgs = generator(lr_imgs)
                    loss = content_criterion(sr_imgs, hr_imgs)
                    val_loss += loss.item()
            
            avg_val_loss = val_loss / len(val_loader)
            train_logger.info(f"Validation loss: {avg_val_loss:.4f}")
            tb_writer.add_scalar('Loss/Validation', avg_val_loss, epoch)
            
            # Log example images (from the last validation batch)
            tb_writer.add_images('LR Images', lr_imgs[:4], epoch)
            tb_writer.add_images('HR Images', hr_imgs[:4], epoch)
            tb_writer.add_images('SR Images', sr_imgs[:4], epoch, dataformats='NCHW')
    
    # Training finished
    train_logger.info("Training complete!")
    tb_writer.close()

Gradient Accumulation and Learning-Rate Scheduling

Gradient accumulation (grad_accum_steps) and the CosineAnnealingWarmRestarts schedulers are already wired into train() above. The entry point below simply fixes the random seeds for reproducibility and launches training; a standalone sketch of the accumulation pattern follows it.

def main():
    args = parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Fix the random seeds for reproducibility
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        
    # Run the training loop (train() parses its own arguments)
    train()
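
For reference, here is the gradient-accumulation pattern from train() in isolation: a minimal sketch with a dummy model, a single optimizer, and an assumed CUDA device. The loss is divided by the window size, and the optimizer and scaler only step at window boundaries; the scheduler side is simply g_scheduler.step()/d_scheduler.step() once per epoch, as shown above.

import torch
import torch.nn as nn
from torch.amp import GradScaler, autocast

# Dummy model and data, purely for illustration
model = nn.Linear(64, 64).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = GradScaler('cuda')
grad_accum_steps = 4

for step in range(100):
    x = torch.randn(8, 64, device='cuda')
    target = torch.randn(8, 64, device='cuda')

    if step % grad_accum_steps == 0:
        optimizer.zero_grad()  # reset gradients at the start of each window

    with autocast('cuda'):
        loss = nn.functional.l1_loss(model(x), target)

    # Divide by the window size so the effective gradient is an average
    scaler.scale(loss / grad_accum_steps).backward()

    # Step and update the scaler only at window boundaries
    if (step + 1) % grad_accum_steps == 0:
        scaler.step(optimizer)
        scaler.update()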

Extended Run Script

An enhanced launch script for large-scale experiments:

#!/bin/bash
# run_esrgan_large.sh
echo "Launching the large ESRGAN training job..."

# Set the working directory (adjust as needed)
cd /home/vscode/workspace/py-torch-learning/codes/esrgan-project

# Create output directories
mkdir -p data checkpoints logs

# Install dependencies
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

# Record the start time
start_time=$(date +%s)
echo "Experiment started at: $(date)"

# Check GPU status
nvidia-smi

# Build the log file name once so the redirect and later references agree
log_file="training_log_esrgan_$(date +%Y%m%d_%H%M%S).txt"

# Start training in the background
nohup python3 -u main.py \
    --epochs 100 \
    --batch_size 8 \
    --dataset_path ./data \
    --checkpoint_dir ./checkpoints \
    --download_datasets \
    > "$log_file" 2>&1 &

# Record the process ID and log file
echo "Training started in the background, PID: $!"
echo "Log file: $log_file"

# Monitor GPU usage (append a snapshot every 5 minutes)
while true; do
    echo "GPU snapshot: $(date)" >> "$log_file"
    nvidia-smi >> "$log_file" 2>&1
    sleep 300  # 5 minutes
done &
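
To start the job, run the script with bash run_esrgan_large.sh (or chmod +x it first). As in the previous post, nohup keeps training alive after the terminal closes, and tail -f on the generated training_log_esrgan_*.txt file follows progress in real time.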

Experiment Monitoring and Analysis

Use TensorBoard to visualize the training process:

from torch.utils.tensorboard import SummaryWriter

def init_tensorboard(log_dir):
    """Initialize the TensorBoard writer."""
    writer = SummaryWriter(log_dir=log_dir)
    return writer

def log_to_tensorboard(writer, epoch, train_metrics, val_metrics, images):
    """Write training metrics and images to TensorBoard."""
    # Scalar metrics
    writer.add_scalar('Loss/Generator', train_metrics['gen_loss'], epoch)
    writer.add_scalar('Loss/Discriminator', train_metrics['dis_loss'], epoch)
    writer.add_scalar('PSNR/Train', train_metrics['psnr'], epoch)
    writer.add_scalar('PSNR/Validation', val_metrics['psnr'], epoch)
    writer.add_scalar('LearningRate/Generator', 
                     train_metrics['gen_lr'], epoch)
    
    # Example images (every 10 epochs)
    if epoch % 10 == 0:
        lr_img, sr_img, hr_img = images
        writer.add_image('Input/LowResolution', lr_img, epoch)
        writer.add_image('Output/SuperResolution', sr_img, epoch)
        writer.add_image('Target/HighResolution', hr_img, epoch)

Output Summary

  1. Core output: model checkpoints
    Contents: the generator and discriminator weights, optimizer states, and the training epoch saved during training (a minimal save_checkpoint sketch follows this list).
    Path: set by the --checkpoint_dir parameter in config.py, default:
    ./checkpoints/
    File names follow checkpoint_epoch_{epoch}.pth (for example checkpoint_epoch_10.pth).
    Trigger: saved every --save_freq epochs (default 10), adjustable from the command line.
  2. Log files
    Contents: loss values, validation metrics, and key operations (download/extraction progress, model loading, etc.).
    Paths:
    Text log: set by the --log_dir parameter in config.py, default ./logs/train/train.log.
    Console log: when launched via run_esrgan_large.sh, output is redirected to training_log_esrgan_$(date).txt in the same directory as the script.
  3. TensorBoard visualizations
    Contents: training/validation loss curves and generated images (LR input, HR ground truth, and SR prediction side by side).
    Path: ./logs/tensorboard/ by default (set by init_tensorboard in main.py, based on --log_dir).
    How to view: run tensorboard --logdir=./logs/tensorboard and open the local port in a browser.
    Note: the tensorboard and protobuf versions must be compatible.
  4. Super-resolution result images (optional)
    Contents: SR images produced during validation or inference.
    Path: set by the --result_dir parameter in config.py, default ./results/ (the directory is created in the code; saving images can be added to the inference logic).
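
train() calls a save_checkpoint helper that is not reproduced in this post. A minimal sketch, assuming the same dictionary keys that the resume logic in train() reads back, might look like this:

import os
import torch

def save_checkpoint(epoch, generator, discriminator, g_optimizer, d_optimizer,
                    checkpoint_dir, logger):
    """Save generator/discriminator weights and optimizer states for resuming (simplified sketch)."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pth")
    torch.save({
        'epoch': epoch,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'g_optimizer_state_dict': g_optimizer.state_dict(),
        'd_optimizer_state_dict': d_optimizer.state_dict(),
    }, path)
    logger.info(f"Checkpoint saved to {path}")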

Summary and Next Steps

In this post we accomplished the following:

1. Large dataset management: automatic download, extraction, and combination of multiple large datasets, plus an optimized data loading pipeline
2. Complex model construction: an ESRGAN model built from residual dense blocks, which produces far richer detail than SRCNN
3. Advanced training strategies: mixed-precision training, gradient accumulation, and cosine-annealing learning-rate scheduling to improve GPU utilization
4. A complete monitoring setup: log files, GPU monitoring, and TensorBoard visualization to track the whole experiment

Directions to explore next:

  • Try larger models (e.g. RCAN, SwinIR)
  • Introduce perceptual losses and improved GAN variants (e.g. Relativistic GAN)
  • Use data and model parallelism to train on multiple GPUs
  • Explore model compression and acceleration for real-time super-resolution
  • Move to video super-resolution, taking temporal consistency into account

In the next post we will explore vision Transformer models for super-resolution and push reconstruction quality even further.

posted on 2025-09-25 21:21 by cmxcxd