import os  # standard library: OS utilities (not used directly in this file)
import torch  # main PyTorch library
from pathlib import Path  # path handling
from collections import OrderedDict  # ordered dict: keeps visuals/losses in a fixed order
from abc import ABC, abstractmethod  # abstract base class support
from . import networks  # network construction and training-schedule utilities (schedulers, etc.)
class BaseModel(ABC):
"""This class is an abstract base class (ABC) for models.
To create a subclass, you need to implement the following five functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
-- <set_input>: unpack data from dataset and apply preprocessing.
-- <forward>: produce intermediate results.
-- <optimize_parameters>: calculate losses, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
"""
    # ^ The docstring above: this class is the abstract base class for all models; it lists the five key methods every subclass must implement
def __init__(self, opt):
"""Initialize the BaseModel class.
Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
When creating your custom class, you need to implement your own initialization.
In this function, you should first call <BaseModel.__init__>(self, opt)
Then, you need to define four lists:
-- self.loss_names (str list): specify the training losses that you want to plot and save.
-- self.model_names (str list): define networks used in our training.
-- self.visual_names (str list): specify the images that you want to display and save.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
"""
        self.opt = opt  # store the global/command-line configuration
        self.gpu_ids = opt.gpu_ids  # list of device ids
        self.isTrain = opt.isTrain  # train/test flag
        # choose the device based on gpu_ids: use the first GPU if any are given, otherwise the CPU
        self.device = torch.device("cuda:{}".format(self.gpu_ids[0])) if self.gpu_ids else torch.device("cpu")  # get device name: CPU or GPU
        # checkpoint directory: <checkpoints_dir>/<experiment_name>
        self.save_dir = Path(opt.checkpoints_dir) / opt.name  # save all the checkpoints to save_dir
        # speed up cuDNN's search for the best convolution algorithm: enable benchmark mode unless preprocessing is scale_width
        if opt.preprocess != "scale_width":  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
            torch.backends.cudnn.benchmark = True
        # cuDNN is NVIDIA's GPU-accelerated library for deep learning, optimized for common ops such as convolution and pooling.
        # torch.backends.cudnn.benchmark = True enables cuDNN's benchmark mode:
        # on the first run, cuDNN briefly benchmarks the convolution algorithms available on the current hardware
        # and picks the fastest one for the current input size.
        # Subsequent convolutions then reuse that algorithm, avoiding per-run algorithm selection and improving overall performance.
        # the lists below are filled in by subclasses' __init__ and drive logging, saving, display, and optimizer management
        self.loss_names = []  # names of the losses to log/plot/save (without the 'loss_' prefix)
        self.model_names = []  # name suffixes of the networks to manage/save/load (e.g., ['G', 'D'])
        self.visual_names = []  # names of the image tensors to display/save to the HTML page
        self.optimizers = []  # list of optimizers (usually one per entry in model_names)
        self.image_paths = []  # image paths of the current batch (for logging/visualization)
        self.metric = 0  # metric monitored by the 'plateau' learning-rate policy
@staticmethod
def modify_commandline_options(parser, is_train):
"""Add new model-specific options, and rewrite default values for existing options.
Parameters:
parser -- original option parser
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
"""
        # subclasses may override this static method to add model-specific argparse options or override defaults; the base class changes nothing
return parser
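    # A minimal override sketch for a subclass (--lambda_example is a hypothetical option;
    # pool_size is an existing training option in this codebase):
    #
    #     @staticmethod
    #     def modify_commandline_options(parser, is_train):
    #         parser.add_argument('--lambda_example', type=float, default=10.0, help='hypothetical model-specific weight')
    #         if is_train:
    #             parser.set_defaults(pool_size=0)  # override the default of an existing option
    #         return parser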
@abstractmethod
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): includes the data itself and its metadata information.
"""
        # abstract method: take one batch from the dataloader, preprocess it, and move it to self.device
pass
@abstractmethod
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
        # abstract method: the forward pass, used in both training and testing
pass
@abstractmethod
def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
        # abstract method: one full optimization step per training iteration (compute losses -> backpropagate -> update parameters)
pass
def setup(self, opt):
"""Load and print networks; create schedulers
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
        # training phase: create learning-rate schedulers (step/plateau/cosine, etc.) for the current optimizers
if self.isTrain:
self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        # test phase or resumed training: load previously saved network weights by epoch / iter
        if not self.isTrain or opt.continue_train:
            load_suffix = "iter_%d" % opt.load_iter if opt.load_iter > 0 else opt.epoch  # prefer the specified iter; otherwise fall back to the epoch
self.load_networks(load_suffix)
        # print parameter counts and (optionally) the architectures
self.print_networks(opt.verbose)
        # if available (PyTorch 2.0+), optionally apply torch.compile to the networks for runtime speedups
if hasattr(torch, "compile"):
self.compile_networks()
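    # Typical call-site sketch (assumed to mirror train.py of this codebase, where create_model is the model factory):
    #
    #     model = create_model(opt)  # build the BaseModel subclass selected by opt.model
    #     model.setup(opt)           # create schedulers and/or load weights, then print the networks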
def eval(self):
"""Make models eval mode during test time"""
        # at test time, switch each network to eval() (affects BatchNorm/Dropout behavior)
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, "net" + name)
net.eval()
def test(self):
"""Forward function used in test time.
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
It also calls <compute_visuals> to produce additional visualization results
"""
        # test-time forward: disable gradients to avoid storing intermediate results, then compute extra visualization outputs
with torch.no_grad():
self.forward()
self.compute_visuals()
def compute_visuals(self):
"""Calculate additional output images for visdom and HTML visualization"""
        # hook for subclasses: generate extra visualization images (e.g., intermediate features / reconstructions)
pass
def get_image_paths(self):
"""Return image paths that are used to load current data"""
        # return the raw image paths of the current batch (usually filled in by set_input())
return self.image_paths
def update_learning_rate(self):
"""Update learning rates for all the networks; called at the end of every epoch"""
        # record the learning rate before the update (read from the first param_group of the first optimizer)
        old_lr = self.optimizers[0].param_groups[0]["lr"]
        # step the schedulers according to the policy: 'plateau' needs the monitored metric passed in; other policies just call step()
        for scheduler in self.schedulers:
            if self.opt.lr_policy == "plateau":
                scheduler.step(self.metric)
            else:
                scheduler.step()
        # print the learning-rate change
lr = self.optimizers[0].param_groups[0]["lr"]
print("learning rate %.7f -> %.7f" % (old_lr, lr))
def get_current_visuals(self):
"""Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
        # collect the tensors listed in visual_names into an ordered dict keyed by name (for logging and the HTML page)
visual_ret = OrderedDict()
for name in self.visual_names:
if isinstance(name, str):
visual_ret[name] = getattr(self, name)
return visual_ret
def get_current_losses(self):
"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
# 将 loss_names 中列出的损失张量(属性名为 'loss_'+name)读取并转成 float
errors_ret = OrderedDict()
for name in self.loss_names:
if isinstance(name, str):
errors_ret[name] = float(getattr(self, "loss_" + name)) # float(...) works for both scalar tensor and float number
return errors_ret
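    # Naming-convention sketch: if a subclass sets self.loss_names = ['G', 'D'], it is expected to
    # assign self.loss_G and self.loss_D (scalar tensors or floats) during optimize_parameters();
    # get_current_losses() then returns OrderedDict([('G', <float>), ('D', <float>)]).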
def save_networks(self, epoch):
        # saves the network weights (.pth files) from each training epoch to disk
"""Save all the networks to the disk.
Parameters:
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
"""
        # save every network in model_names to <save_dir>/<epoch>_net_<name>.pth
for name in self.model_names:
if isinstance(name, str):
save_filename = f"{epoch}_net_{name}.pth"
save_path = self.save_dir / save_filename
net = getattr(self, "net" + name)
if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    # the net may be wrapped in DataParallel/DistributedDataParallel, so take .module;
                    # move the net to the CPU before saving, then move it back to the first GPU
torch.save(net.module.cpu().state_dict(), save_path)
net.cuda(self.gpu_ids[0])
else:
torch.save(net.cpu().state_dict(), save_path)
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
        # recursively patch state_dict key/buffer incompatibilities of InstanceNorm checkpoints from old PyTorch (< 0.4)
key = keys[i]
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
            # for InstanceNorm, if the key is running_mean / running_var and the corresponding attribute is None, drop the key from the state_dict
if module.__class__.__name__.startswith("InstanceNorm") and (key == "running_mean" or key == "running_var"):
if getattr(module, key) is None:
state_dict.pop(".".join(keys))
            # likewise drop num_batches_tracked (absent in old versions)
if module.__class__.__name__.startswith("InstanceNorm") and (key == "num_batches_tracked"):
state_dict.pop(".".join(keys))
else:
            # recurse into the next submodule
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
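    # Background sketch: InstanceNorm layers created with track_running_stats=False (the default for
    # nn.InstanceNorm2d, and what this codebase's norm layers use) keep these buffers as None, so stale
    # keys from pre-0.4 checkpoints must be removed before load_state_dict():
    #
    #     import torch.nn as nn
    #     norm = nn.InstanceNorm2d(64, track_running_stats=False)
    #     print(norm.running_mean, norm.running_var)  # None None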
def load_networks(self, epoch):
"""Load all the networks from the disk.
Parameters:
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
"""
        # load each network: read the state_dict from <save_dir>/<epoch>_net_<name>.pth and load it into the model
for name in self.model_names:
if isinstance(name, str):
load_filename = f"{epoch}_net_{name}.pth"
load_path = self.save_dir / load_filename
net = getattr(self, "net" + name)
                # if wrapped in DataParallel, unwrap to the actual module first
if isinstance(net, torch.nn.DataParallel):
net = net.module
print(f"loading the model from {load_path}")
                # note: map_location takes the device string; weights_only=True loads only weight tensors (safer)
state_dict = torch.load(load_path, map_location=str(self.device), weights_only=True)
                # some save formats attach _metadata; delete it since it is unused and may interfere
if hasattr(state_dict, "_metadata"):
del state_dict._metadata
                # compatibility patch for old-version InstanceNorm (the dict is mutated while iterating, so copy the key list first)
for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
self.__patch_instance_norm_state_dict(state_dict, net, key.split("."))
net.load_state_dict(state_dict)
def compile_networks(self, **compile_kwargs):
"""Apply torch.compile to all networks for optimization.
Parameters:
**compile_kwargs -- keyword arguments to pass to torch.compile
(e.g., mode='reduce-overhead', backend='inductor')
"""
        # apply torch.compile (the PyTorch 2.0 compiler stack) to each network; can significantly speed up inference/training
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, "net" + name)
compiled_net = torch.compile(net, **compile_kwargs)
setattr(self, "net" + name, compiled_net) # 将编译后的网络回写到实例属性
print(f"[Network {name}] compiled with torch.compile")
setattr(self, "net" + name, compiled_net) # 再次回写(功能上与上一行重复,虽无害但属冗余)
def print_networks(self, verbose):
"""Print the total number of parameters in the network and (if verbose) network architecture
Parameters:
verbose (bool) -- if verbose: print the network architecture
"""
        # print parameter counts; if verbose=True, also print the full architecture
print("---------- Networks initialized -------------")
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, "net" + name)
num_params = 0
for param in net.parameters():
                    num_params += param.numel()  # accumulate the element count of every parameter tensor
if verbose:
print(net)
print("[Network %s] Total number of parameters : %.3f M" % (name, num_params / 1e6))
print("-----------------------------------------------")
def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
Parameters:
nets (network list) -- a list of networks
requires_grad (bool) -- whether the networks require gradients or not
"""
        # set the requires_grad flag uniformly on a single network or a list of networks (commonly used to freeze a discriminator/feature extractor to save computation)
if not isinstance(nets, list):
nets = [nets]
for net in nets:
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad
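    # Usage sketch from a typical GAN iteration (attribute names follow cycle_gan_model.py and are
    # illustrative here):
    #
    #     self.set_requires_grad([self.netD_A, self.netD_B], False)  # Ds need no gradients while updating Gs
    #     self.optimizer_G.zero_grad()
    #     self.backward_G()
    #     self.optimizer_G.step()
    #     self.set_requires_grad([self.netD_A, self.netD_B], True)   # re-enable them for the D update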
About load_networks
Why load a saved .pth back into the model?
Recovery and resumed training: after an interruption (power loss, crash, switching machines), continue from the latest checkpoint instead of starting over.
Inference and deployment: online/offline inference needs the trained weights, not an empty model; correct outputs are only possible after loading them.
Reproducing experiments: reproducing a paper/project, or revisiting the results of a particular epoch/iter, requires loading the weights from that point in time.
Comparison and rollback: to compare, say, epoch 50 against epoch 100, or to roll back when an earlier version works better, you need to load different weight versions at will.
Transfer/fine-tuning: load weights pretrained on a large dataset first, then fine-tune on new data; faster and more stable than training from scratch.
Portability and robustness: saving/loading a state_dict (weights only) is more robust than serializing the whole model object:
it does not depend on an identical Python object layout and environment;
it is less error-prone across machines and versions;
keys can be edited/filtered by hand for compatibility patches (exactly what the InstanceNorm logic inside the function does).
load_networks restores only the network parameters; for seamless resumed training, the optimizer/scheduler state must also be saved/loaded (this base class does not do that; many projects store a separate optim.pth), as sketched below.
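A minimal sketch of saving/loading that extra state (the optim.pth file name and the placement of these snippets are assumptions, not part of this codebase):

    # save, e.g., at the end of save_networks(epoch)
    optim_state = {"optimizers": [o.state_dict() for o in self.optimizers],
                   "schedulers": [s.state_dict() for s in self.schedulers]}
    torch.save(optim_state, self.save_dir / f"{epoch}_optim.pth")

    # load, e.g., at the end of load_networks(epoch) when opt.continue_train is set
    optim_state = torch.load(self.save_dir / f"{epoch}_optim.pth", map_location=str(self.device))
    for o, sd in zip(self.optimizers, optim_state["optimizers"]):
        o.load_state_dict(sd)
    for s, sd in zip(self.schedulers, optim_state["schedulers"]):
        s.load_state_dict(sd)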
Summary:
BaseModel.__init__()
└── initialize member variables (no internal method calls)
BaseModel.setup()
├── self.load_networks()        # load network weights
├── self.print_networks()       # print network info
└── self.compile_networks()     # (conditional) compile the networks if torch.compile is available
BaseModel.test()
├── self.forward()              # forward pass (abstract method, implemented by subclasses)
└── self.compute_visuals()      # compute visualization results (empty by default)
BaseModel.update_learning_rate()
└── scheduler.step()            # step the LR schedulers (created via the networks module)
BaseModel.save_networks()
└── torch.save()                # save the networks' state dicts
BaseModel.load_networks()
├── torch.load()                # load a network state dict
├── self.__patch_instance_norm_state_dict()  # fix InstanceNorm compatibility issues
└── net.load_state_dict()       # load the state dict into the network
BaseModel.compile_networks()
└── torch.compile()             # compile the networks (PyTorch 2.0+ feature)
BaseModel.print_networks()
└── iterate over network parameters (no internal method calls; prints info)
BaseModel.set_requires_grad()
└── set flags on network parameters directly (no internal method calls)
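Finally, a minimal subclass sketch tying the pieces together (a simplified, hypothetical model; the option names and the networks.define_G call follow this codebase, but a real model would mirror pix2pix_model.py or cycle_gan_model.py):

    import torch
    from .base_model import BaseModel
    from . import networks

    class MinimalModel(BaseModel):
        def __init__(self, opt):
            BaseModel.__init__(self, opt)
            self.loss_names = ["G"]               # get_current_losses() will read self.loss_G
            self.visual_names = ["real", "fake"]  # get_current_visuals() will read self.real / self.fake
            self.model_names = ["G"]              # save/load/print will operate on self.netG
            self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
                                          opt.norm, not opt.no_dropout, opt.init_type,
                                          opt.init_gain, self.gpu_ids)
            if self.isTrain:
                self.criterion = torch.nn.L1Loss()
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr)
                self.optimizers.append(self.optimizer_G)

        def set_input(self, input):
            self.real = input["A"].to(self.device)
            self.target = input["B"].to(self.device)
            self.image_paths = input["A_paths"]

        def forward(self):
            self.fake = self.netG(self.real)

        def optimize_parameters(self):
            self.forward()
            self.optimizer_G.zero_grad()
            self.loss_G = self.criterion(self.fake, self.target)
            self.loss_G.backward()
            self.optimizer_G.step()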