"""This package contains modules related to objective functions, optimizations, and network architectures.
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
-- <set_input>: unpack data from dataset and apply preprocessing.
-- <forward>: produce intermediate results.
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
"""
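# The documentation also spells out the conventions for adding a custom model (file naming, the methods to implement, the lists to maintain in __init__, etc.).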
"""In the function <__init__>, you need to define four lists:
-- self.loss_names (str list): specify the training losses that you want to plot and save.
-- self.model_names (str list): define networks used in our training.
-- self.visual_names (str list): specify the images that you want to display and save.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example.
Now you can use the model class by specifying flag '--model dummy'.
See our template model class 'template_model.py' for more details.
"""
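# ↑ Package-level docstring: explains that this package covers objective functions, optimization, and network architectures;
# it is left essentially as written because it already reads as documentation.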
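import importlib  # standard-library dynamic import (loads a Python module from its dotted-name string)
from models.base_model import BaseModel  # base class every model is expected to inherit from; used for the subclass check below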
def find_model_using_name(model_name: str):
    """Import the module "models/[model_name]_model.py" and return its model class.

    The module must define a class whose name, compared case-insensitively,
    equals ``model_name`` with underscores removed followed by ``model``
    (e.g. ``'cycle_gan'`` -> ``'cycleganmodel'``), and that class must be a
    subclass of BaseModel.

    Args:
        model_name: Model identifier used to build both the module path and
            the expected class name, e.g. ``"cycle_gan"``.

    Returns:
        The matching class object itself (not an instance).

    Raises:
        ImportError: Propagated from ``importlib.import_module`` when
            ``models.<model_name>_model`` cannot be imported.
        SystemExit: With exit status 1, after printing a hint, when the
            module contains no suitably named subclass of BaseModel.
    """
    # Build the dotted module path, e.g. 'cycle_gan' -> 'models.cycle_gan_model'.
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)

    # Normalized class name to look for: underscores dropped, 'model' appended.
    target_model_name = model_name.replace("_", "") + "model"
    wanted = target_model_name.lower()  # hoisted: invariant across the loop

    model = None
    # Walk the module namespace: `name` is the symbol name, `cls` the object bound to it.
    for name, cls in modellib.__dict__.items():
        # The isinstance guard keeps issubclass() from raising TypeError when a
        # non-class attribute (function, constant, ...) happens to carry the
        # expected name; such attributes are simply not a match.
        if (
            name.lower() == wanted
            and isinstance(cls, type)
            and issubclass(cls, BaseModel)
        ):
            model = cls  # no early return: a later matching name overrides an earlier one

    if model is None:
        print(
            "In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase."
            % (model_filename, target_model_name)
        )
        # Stop here rather than let callers continue without a model class.
        # A non-zero status tells the calling shell/scheduler that the run
        # failed; raising SystemExit directly also avoids depending on the
        # `exit` helper that is only injected by the `site` module.
        raise SystemExit(1)
    return model
def get_option_setter(model_name: str):
    """Look up the model class for ``model_name`` and hand back its option hook.

    The returned object is the class's ``modify_commandline_options``
    attribute itself; it is not called here.
    """
    # Resolve the class, then expose its option-modifying hook to the caller.
    return find_model_using_name(model_name).modify_commandline_options
def create_model(opt):
    """Instantiate the model selected by ``opt.model``.

    The class is resolved from the string ``opt.model`` and constructed with
    ``opt`` as its only argument; a one-line notice naming the concrete class
    is printed before the instance is returned.
    """
    model_cls = find_model_using_name(opt.model)  # class object, resolved by name
    instance = model_cls(opt)  # the full option object is forwarded to the constructor
    # Log which concrete class was built, to make runs easier to trace.
    print(f"model [{type(instance).__name__}] was created")
    return instance