Loading

【cv】cycleGAN代码解析:train.py

import time                                   # 计时:统计每轮/每次迭代耗时
from options.train_options import TrainOptions # 训练期命令行参数解析器(继承 BaseOptions 并添加训练相关项)
from data import create_dataset                # 工厂函数:按 opt.dataset_mode 创建数据集实例
from models import create_model                # 工厂函数:按 opt.model 创建对应模型(如 CycleGAN、pix2pix)
from util.visualizer import Visualizer         # 可视化与日志工具:显示/保存图像、打印/绘制损失曲线
if __name__ == "__main__":                     # 仅在作为脚本运行时执行,避免被 import 时跑训练
    opt = TrainOptions().parse()               # 解析命令行参数,得到训练配置对象 opt
	# 从train_options.py里得到命令行参数的对象
	
    dataset = create_dataset(opt)              # 依据选项创建数据集(并可能内部创建 DataLoader)
	# 这里的dataset对象里包含数据集对象以及dataLoader
	
    dataset_size = len(dataset)                # 获取数据集中图像数量(通常是样本数或步数的近似)
    print(f"The number of training images = {dataset_size}")  # 打印样本数量,便于确认数据是否加载正确

	# 创建模型并进行一定的初始化操作
    model = create_model(opt)                  # 创建指定的模型实例(如 CycleGANModel)
    model.setup(opt)                           # 常规设置:加载/打印网络,创建学习率调度器(若训练/续训)
    visualizer = Visualizer(opt)               # 创建可视化器:负责 visdom/HTML 图像展示与损失曲线绘制
    total_iters = 0                            # 全局累计的“迭代步数”(以样本/批为单位累加)

    for epoch in range(                        # 外层按 epoch 循环
        opt.epoch_count,                       # 起始 epoch(支持从中间轮次续训)
        opt.n_epochs + opt.n_epochs_decay + 1  # 训练轮数 = 预热阶段 + 线性衰减阶段;+1 使得上界包含在内
    ):
	# 构造for循环,开始每个epoch的操作
	
        epoch_start_time = time.time()         # 记录该轮开始时间(统计整轮耗时)
        iter_data_time = time.time()           # 记录上一次取数据的时间(用于统计 data loading 时间)
	# (先定义好iter_data_time这个变量)
        epoch_iter = 0                         # 当前 epoch 内已处理的样本数(或近似步数),每轮重置
	# 记录当前epoch里经历的iter数
        visualizer.reset()                     # 重置可视化器:确保至少每个 epoch 会把结果写入 HTML

enumerate机制和可迭代对象

enumerate

enumerate是 Python 的内置函数,核心作用是为可迭代对象的元素添加索引,方便在迭代时同时获取 “索引” 和 “元素值”

enumerate(iterable)会返回一个枚举对象(enumerate object),它本质是一个迭代器(iterator)。每次迭代时,这个迭代器会返回一个元组(索引, 元素),其中:

第一个元素是当前迭代的索引(默认从 0 开始,可通过start参数指定起始值,如enumerate(lst, start=1));
第二个元素是可迭代对象iterable中的对应元素。

因此,在for i, item in enumerate(iterable):中,i接收索引,item接收元素,这是对元组(i, item)的解包操作

可迭代对象

enumerate函数的传入参数必须是可迭代对象

在 Python 中,可迭代对象需要满足:实现了__iter__()方法,该方法返回一个迭代器(iterator);而迭代器需要实现__next__()方法(用于返回下一个元素)和__iter__()方法(返回自身)

参考文献:https://zhuanlan.zhihu.com/p/7364648529

        for i, data in enumerate(dataset):     # 内层按批次迭代数据集(DataLoader 可迭代)
		# 当这里迭代dataset对象时,其实是在迭代里面的dataLoader对象,以Iter为单位
		# 即,每次取出一个iter的数据
		# data对象里包含着dataLoader本次取出的数据的相关信息,在后面的set_input方法里会把这些数据加载进去
		
            iter_start_time = time.time()      # 记录本次迭代计算开始时间
            if total_iters % opt.print_freq == 0:      # 每隔 print_freq 次统计一次“取数耗时”
			# parser.add_argument('--print_freq', type=int, default=100, 
			# help='frequency of showing training results on console')
			# print_freq是打印训练结果的频率
                t_data = iter_start_time - iter_data_time

            total_iters += opt.batch_size       # 全局步数累加(以 batch_size 作为步长)
            epoch_iter += opt.batch_size        # 当前轮的步数累加
			
            model.set_input(data)               # 解包 dataloader 返回的 data,并搬到正确的 device
		# 这里模型已经得到了dataLoader里的数据了
            model.optimize_parameters()         # 前向、计算损失、反传、更新网络参数(一次标准训练步)
		# 这一句就是训练的核心操作

回顾一下这个函数

    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # 1) 前向:生成假图与重建图
        self.forward()
        # 2) 优化生成器(冻结两个判别器的梯度)
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # 3) 优化判别器(解冻判别器)
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()

            if total_iters % opt.display_freq == 0:  # 到了展示频率:显示图像并(可选)保存到 HTML
                save_result = total_iters % opt.update_html_freq == 0  # 是否本次也写 HTML(更低频地写盘)
                model.compute_visuals()         # 生成额外可视化结果(由模型实现,如重建/中间图)
                visualizer.display_current_results(
                    model.get_current_visuals(),  # 从模型取出需要展示的可视化张量
                    epoch,                        # 当前 epoch 序号
                    total_iters,                  # 全局迭代计数(用于命名与记录)
                    save_result                   # 是否保存 HTML(否则只在 visdom 上显示)
                )

            if total_iters % opt.print_freq == 0:    # 到了打印频率:打印损失并记录到磁盘
                losses = model.get_current_losses()  # 从模型取回当前各项损失(有序字典)
                t_comp = (time.time() - iter_start_time) / opt.batch_size  # 计算每张/每样本平均计算耗时
                visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)  # 控制台打印
                visualizer.plot_current_losses(total_iters, losses)                         # 动态曲线

            if total_iters % opt.save_latest_freq == 0:  # 到了“保存最新模型”的频率:保存快照
                print(f"saving the latest model (epoch {epoch}, total_iters {total_iters})")
                save_suffix = f"iter_{total_iters}" if opt.save_by_iter else "latest"  # 可按迭代编号或统一 latest
                model.save_networks(save_suffix)  # 以 <suffix> 作为文件名后缀存盘(见 BaseModel.save_networks)

            iter_data_time = time.time()         # 更新“上一次取数时间”,用于下一个 batch 的 t_data 统计

        model.update_learning_rate()             # 每个 epoch 结束时根据策略更新学习率(调度器 step)

        if epoch % opt.save_epoch_freq == 0:     # 到了“按 epoch 频率保存”的时刻:保存 latest 与该轮编号
            print(f"saving the model at the end of epoch {epoch}, iters {total_iters}")
            model.save_networks("latest")        # 保存成 latest(覆盖)
            model.save_networks(epoch)           # 再保存一个以 epoch 编号命名的权重(留存历史)

        print(                                   # 打印本轮结束信息与整轮耗时(四舍五入到秒)
            f"End of epoch {epoch} / {opt.n_epochs + opt.n_epochs_decay} \t "
            f"Time Taken: {time.time() - epoch_start_time:.0f} sec"
        )

posted @ 2025-09-25 16:37  SaTsuki26681534  阅读(40)  评论(0)    收藏  举报