[CV] GAN Code Walkthrough: base_model.py

import os  # standard library: OS utilities (not actually used in this file)
import torch  # main PyTorch library
from pathlib import Path  # path handling
from collections import OrderedDict  # ordered dict: keeps visuals/losses in a fixed order
from abc import ABC, abstractmethod  # abstract-base-class support
from . import networks  # network construction and training-schedule helpers (schedulers, network builders, etc.)


class BaseModel(ABC):
    """This class is an abstract base class (ABC) for models.
    To create a subclass, you need to implement the following five functions:
        -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>:                     unpack data from dataset and apply preprocessing.
        -- <forward>:                       produce intermediate results.
        -- <optimize_parameters>:           calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.
    """
    # ↑ Docstring: this class is the abstract base class for all models; it lists the five key methods a subclass must implement

    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__>(self, opt)
        Then, you need to define four lists:
            -- self.loss_names (str list):          specify the training losses that you want to plot and save.
            -- self.model_names (str list):         define networks used in our training.
            -- self.visual_names (str list):        specify the images that you want to display and save.
            -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
        """
        self.opt = opt  # keep the global/command-line configuration
        self.gpu_ids = opt.gpu_ids  # list of GPU device ids
        self.isTrain = opt.isTrain  # train/test flag

        # Pick the device from gpu_ids: the first GPU if the list is non-empty, otherwise the CPU
        self.device = torch.device("cuda:{}".format(self.gpu_ids[0])) if self.gpu_ids else torch.device("cpu")  # get device name: CPU or GPU

        # Checkpoint directory: <checkpoints_dir>/<experiment_name>
        self.save_dir = Path(opt.checkpoints_dir) / opt.name  # save all the checkpoints to save_dir

        # Speed up cudnn's search for the best convolution algorithm: enable benchmark mode
        # whenever the preprocessing is not scale_width.
        if opt.preprocess != "scale_width":  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
            torch.backends.cudnn.benchmark = True
# cudnn is NVIDIA's GPU acceleration library for deep learning; it provides tuned
# implementations of common operations such as convolution and pooling.
# torch.backends.cudnn.benchmark = True enables cudnn's benchmark mode:
# on the first run, cudnn briefly benchmarks the convolution algorithms available on the
# current hardware and picks the fastest one for the current input size.
# Subsequent convolutions then reuse that algorithm, avoiding the cost of re-selecting
# one on every call and improving overall throughput.


        # The lists below are filled in by the subclass's __init__; they drive logging,
        # saving, visualization, and optimizer management.
        self.loss_names = []      # names of losses to log/plot/save (without the 'loss_' prefix)
        self.model_names = []     # name suffixes of the networks to manage/save/load (e.g. ['G', 'D'])
        self.visual_names = []    # names of image tensors to visualize/save to the HTML page
        self.optimizers = []      # list of optimizers (usually one per entry in model_names)
        self.image_paths = []     # image paths of the current batch (for logging/visualization)
        self.metric = 0           # metric monitored by the 'plateau' learning-rate policy

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new model-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        # Subclasses may override this static method to add model-specific argparse options or override default values; the base implementation changes nothing
        return parser
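
    # For reference, a subclass can override this hook to add options and tweak defaults.
    # A sketch modeled on pix2pix_model.py's override (option names taken from that model):
    #
    #     @staticmethod
    #     def modify_commandline_options(parser, is_train):
    #         parser.add_argument("--lambda_L1", type=float, default=100.0, help="weight for the L1 loss")
    #         if is_train:
    #             parser.set_defaults(pool_size=0, gan_mode="vanilla")
    #         return parser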

    @abstractmethod
    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): includes the data itself and its metadata information.
        """
        # Abstract method: fetch one batch from the dataloader, apply the necessary preprocessing, and move tensors to self.device
        pass

    @abstractmethod
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        # Abstract method: the forward computation, used in both training and testing
        pass

    @abstractmethod
    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # Abstract method: one complete optimization step per training iteration (compute losses → backprop → update parameters)
        pass

    def setup(self, opt):
        """Load and print networks; create schedulers

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        # Training phase: build learning-rate schedulers (step/plateau/cosine, etc.) on top of the current optimizers
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]

        # Test phase, or resuming an interrupted run: load previously saved network weights by epoch / iter
        if not self.isTrain or opt.continue_train:
            load_suffix = "iter_%d" % opt.load_iter if opt.load_iter > 0 else opt.epoch  # prefer the specified iter; otherwise fall back to the epoch
            self.load_networks(load_suffix)
        # Print the parameter counts and (optionally) the architecture
        self.print_networks(opt.verbose)

        # Where available (PyTorch 2.0+), optionally apply torch.compile to the networks for runtime optimization
        if hasattr(torch, "compile"):
            self.compile_networks()

    def eval(self):
        """Make models eval mode during test time"""
        # At test time, switch every network to eval() (affects BatchNorm/Dropout behavior, among others)
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, "net" + name)
                net.eval()

    def test(self):
        """Forward function used in test time.

        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
        It also calls <compute_visuals> to produce additional visualization results
        """
        # Test-time forward pass: disable gradients so intermediate results are not stored, then compute the extra visualization outputs
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization"""
        # Hook: implemented by subclasses to generate extra visualization images (e.g. intermediate features / reconstructions)
        pass

    def get_image_paths(self):
        """Return image paths that are used to load current data"""
        # Return the source image paths of the current batch (usually populated by set_input())
        return self.image_paths

    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch"""
        # Record the learning rate before the update (read from the first param_group of the first optimizer)
        old_lr = self.optimizers[0].param_groups[0]["lr"]
        # Advance each scheduler: the 'plateau' policy needs the monitored metric; the other policies just call step()
        for scheduler in self.schedulers:
            if self.opt.lr_policy == "plateau":
                scheduler.step(self.metric)
            else:
                scheduler.step()
        # Print the learning-rate change
        lr = self.optimizers[0].param_groups[0]["lr"]
        print("learning rate %.7f -> %.7f" % (old_lr, lr))

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
        # Fetch the tensors listed in visual_names and organize them by name into an OrderedDict (for logging and the HTML page)
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_current_losses(self):
        """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
        # Read the loss tensors listed in loss_names (attribute name 'loss_' + name) and convert them to floats
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, "loss_" + name))  # float(...) works for both scalar tensor and float number
        return errors_ret

    def save_networks(self, epoch):
# This method saves the network weights from each training epoch to local .pth files
        """Save all the networks to the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        # Save every network listed in model_names to <save_dir>/<epoch>_net_<name>.pth
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = f"{epoch}_net_{name}.pth"
                save_path = self.save_dir / save_filename
                net = getattr(self, "net" + name)

                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    # The net may be wrapped in DataParallel/DistributedDataParallel (under .module);
                    # move it to the CPU before saving, then move it back to the first GPU
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
        # Recursively patch state_dict key/buffer incompatibilities in InstanceNorm checkpoints saved before PyTorch 0.4
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            # For InstanceNorm, if the key is running_mean / running_var but the corresponding attribute is None, remove the key from the state_dict
            if module.__class__.__name__.startswith("InstanceNorm") and (key == "running_mean" or key == "running_var"):
                if getattr(module, key) is None:
                    state_dict.pop(".".join(keys))
            # Likewise remove num_batches_tracked (not present in old versions)
            if module.__class__.__name__.startswith("InstanceNorm") and (key == "num_batches_tracked"):
                state_dict.pop(".".join(keys))
        else:
            # Recurse into the next-level submodule
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
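
# Example of what this patch handles: a checkpoint written before PyTorch 0.4 may contain
# keys such as 'model.1.running_mean' for an InstanceNorm layer whose running_mean
# attribute is None in the current module (track_running_stats=False); popping those keys
# lets net.load_state_dict() in load_networks() succeed without unexpected-key errors.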

    def load_networks(self, epoch):
        """Load all the networks from the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        # Load each network: read the state_dict from <save_dir>/<epoch>_net_<name>.pth and load it into the model
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = f"{epoch}_net_{name}.pth"
                load_path = self.save_dir / load_filename
                net = getattr(self, "net" + name)
                # If the net is wrapped in DataParallel, unwrap the underlying module first
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print(f"loading the model from {load_path}")
                # Note: map_location takes a device string; weights_only=True loads weight tensors only (safer)
                state_dict = torch.load(load_path, map_location=str(self.device), weights_only=True)
                # Some save formats carry _metadata; delete it, since it is unused here and may interfere
                if hasattr(state_dict, "_metadata"):
                    del state_dict._metadata

                # Compatibility patch for old InstanceNorm checkpoints (copy the key list first, since the dict is mutated during iteration)
                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split("."))
                net.load_state_dict(state_dict)

    def compile_networks(self, **compile_kwargs):
        """Apply torch.compile to all networks for optimization.

        Parameters:
            **compile_kwargs -- keyword arguments to pass to torch.compile
                              (e.g., mode='reduce-overhead', backend='inductor')
        """
        # Apply torch.compile (PyTorch 2.0's compiler) to every network; this can noticeably speed up inference/training
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, "net" + name)
                compiled_net = torch.compile(net, **compile_kwargs)
                setattr(self, "net" + name, compiled_net)  # write the compiled network back onto the instance attribute
                print(f"[Network {name}] compiled with torch.compile")

    def print_networks(self, verbose):
        """Print the total number of parameters in the network and (if verbose) network architecture

        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        # Print parameter-count statistics; if verbose=True, also print the full architecture
        print("---------- Networks initialized -------------")
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, "net" + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()  # accumulate the element count of every parameter tensor
                if verbose:
                    print(net)
                print("[Network %s] Total number of parameters : %.3f M" % (name, num_params / 1e6))
        print("-----------------------------------------------")

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
        Parameters:
            nets (network list)   -- a list of networks
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        # Set the requires_grad flag uniformly on one network or a list of networks (commonly used to freeze the discriminator / a feature extractor to save computation)
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
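
To make the subclass contract concrete, here is a minimal, hypothetical sketch. It is not from the repository: the one-layer toy networks and the assumed opt fields (opt.lr, paired inputs "A"/"B") are placeholders, and a real model (see pix2pix_model.py / cycle_gan_model.py) builds its networks via the networks module instead.

import torch
import torch.nn as nn


class ToyModel(BaseModel):
    """Hypothetical minimal model: one generator/discriminator pair on paired data."""

    def __init__(self, opt):
        BaseModel.__init__(self, opt)  # always call the base initializer first
        self.loss_names = ["G", "D"]              # read back as self.loss_G / self.loss_D
        self.visual_names = ["real_A", "fake_B"]  # tensors collected by get_current_visuals()
        self.model_names = ["G", "D"] if self.isTrain else ["G"]
        # attribute names must be 'net' + suffix so save/load/print/compile can find them
        self.netG = nn.Conv2d(3, 3, 3, padding=1).to(self.device)
        self.netD = nn.Conv2d(3, 1, 3, padding=1).to(self.device)
        if self.isTrain:
            self.criterion = nn.MSELoss()  # LSGAN-style objective, for brevity
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr)
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr)
            self.optimizers = [self.optimizer_G, self.optimizer_D]  # setup() builds schedulers from these

    def set_input(self, input):
        self.real_A = input["A"].to(self.device)
        self.real_B = input["B"].to(self.device)
        self.image_paths = input["A_paths"]

    def forward(self):
        self.fake_B = self.netG(self.real_A)

    def optimize_parameters(self):
        self.forward()
        # update D first; detach fake_B so gradients stop before the generator
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()
        pred_fake = self.netD(self.fake_B.detach())
        pred_real = self.netD(self.real_B)
        self.loss_D = 0.5 * (self.criterion(pred_fake, torch.zeros_like(pred_fake))
                             + self.criterion(pred_real, torch.ones_like(pred_real)))
        self.loss_D.backward()
        self.optimizer_D.step()
        # then update G with D frozen, so no gradients are wasted on D's parameters
        self.set_requires_grad(self.netD, False)
        self.optimizer_G.zero_grad()
        pred_fake = self.netD(self.fake_B)
        self.loss_G = self.criterion(pred_fake, torch.ones_like(pred_fake))
        self.loss_G.backward()
        self.optimizer_G.step()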

About load_networks

Why load a saved .pth back into the model at all?

Recovery and resumed training: after an interruption (power loss, crash, switching machines), continue from the latest checkpoint instead of starting over.
Inference and deployment: online/offline inference needs the trained weights, not an empty model; only after loading them does the model produce correct outputs.
Reproducing experiments: reproducing a paper/project, or revisiting the results of a particular epoch/iter, requires loading the weights from that point in time.
Comparison and rollback: comparing epoch 50 against epoch 100, or rolling back to a version that worked better, requires loading different checkpoints on demand.
Transfer/fine-tuning: loading weights pretrained on a large dataset and then fine-tuning on new data is faster and more stable than training from scratch.
Portability and robustness: saving/loading a state_dict (pure weights) is more robust than serializing the whole model object (see the sketch after this list):
	it does not depend on an identical Python object layout and environment;
	it is less likely to break across machines and versions;
	its keys can be edited/filtered by hand for compatibility patches (exactly what the InstanceNorm logic inside this function does).
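
A small sketch of that contrast (the toy network and file names are illustrative, not from the repository):

import torch
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 3, 3))

# Robust: weights only. Re-create the architecture in code, then restore the tensors.
torch.save(net.state_dict(), "toy_net.pth")
rebuilt = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 3, 3))
state = torch.load("toy_net.pth", map_location="cpu", weights_only=True)
# Keys are plain strings, so compatibility fixes are simple dict surgery,
# in the same spirit as __patch_instance_norm_state_dict above:
state = {k: v for k, v in state.items() if not k.endswith("num_batches_tracked")}
rebuilt.load_state_dict(state)

# Fragile: pickling the whole object ties the file to the exact class/module layout
# that existed at save time (and weights_only=True will refuse to load it back).
torch.save(net, "toy_net_full.pth")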

load_networks only restores the network parameters; for a truly seamless resume, the optimizer/scheduler state has to be saved and reloaded as well (this base class does not do that; many projects store a separate optim.pth, as the sketch below illustrates)
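
A minimal sketch of what such a companion checkpoint could look like (the save_extra_state / load_extra_state helpers and the optim.pth naming are hypothetical, not part of this base class):

import torch


def save_extra_state(model, epoch):
    """Hypothetical companion to save_networks: persist optimizer/scheduler state."""
    torch.save(
        {
            "epoch": epoch,
            "optimizers": [o.state_dict() for o in model.optimizers],
            "schedulers": [s.state_dict() for s in model.schedulers],
        },
        model.save_dir / f"{epoch}_optim.pth",
    )


def load_extra_state(model, epoch):
    """Hypothetical companion to load_networks, for a seamless resume."""
    ckpt = torch.load(model.save_dir / f"{epoch}_optim.pth", map_location=str(model.device))
    for o, state in zip(model.optimizers, ckpt["optimizers"]):
        o.load_state_dict(state)
    for s, state in zip(model.schedulers, ckpt["schedulers"]):
        s.load_state_dict(state)
    return ckpt["epoch"]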

Summary:

BaseModel.__init__()
└── initialize member attributes (no internal method calls)

BaseModel.setup()
├── self.load_networks()  # load network weights
├── self.print_networks()  # print network info
└── self.compile_networks()  # (conditional) compile the networks if torch.compile is available

BaseModel.test()
├── self.forward()  # forward pass (abstract method, implemented by subclasses)
└── self.compute_visuals()  # compute visualization outputs (empty by default)

BaseModel.update_learning_rate()
└── scheduler.step()  # step the learning-rate schedulers (created via the networks module)

BaseModel.save_networks()
└── torch.save()  # save the network state_dicts

BaseModel.load_networks()
├── torch.load()  # load the network state_dicts
├── self.__patch_instance_norm_state_dict()  # fix InstanceNorm compatibility issues
└── net.load_state_dict()  # load the state_dict into the network

BaseModel.compile_networks()
└── torch.compile()  # compile the networks (PyTorch 2.0+ feature)

BaseModel.print_networks()
└── iterate over network parameters (no internal method calls; prints info)

BaseModel.set_requires_grad()
└── operate directly on network parameters (no internal method calls)
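
Putting the pieces together, a typical training loop over this interface looks roughly like the following (a sketch in the spirit of the repository's train.py; create_dataset / create_model and the opt fields are assumed to come from the surrounding framework):

from options.train_options import TrainOptions
from data import create_dataset
from models import create_model

opt = TrainOptions().parse()
dataset = create_dataset(opt)
model = create_model(opt)  # instantiates a BaseModel subclass
model.setup(opt)           # schedulers + (optional) weight loading + print/compile

for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):
    for data in dataset:
        model.set_input(data)        # unpack a batch and move it to the right device
        model.optimize_parameters()  # forward + losses + backward + optimizer steps
    losses = model.get_current_losses()  # OrderedDict of float losses for logging
    model.update_learning_rate()         # step every scheduler once per epoch
    if epoch % opt.save_epoch_freq == 0:
        model.save_networks(epoch)       # writes <save_dir>/<epoch>_net_<name>.pth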