Loading

【CV】GAN代码解析 model init.py

"""This package contains modules related to objective functions, optimizations, and network architectures.

To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
    -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).
    -- <set_input>:                     unpack data from dataset and apply preprocessing.
    -- <forward>:                       produce intermediate results.
    -- <optimize_parameters>:           calculate loss, gradients, and update network weights.
    -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.
"""
#   还给出如何添加自定义模型的约定(文件命名、需实现的方法、在 __init__ 中要维护的列表等)。

"""In the function <__init__>, you need to define four lists:
    -- self.loss_names (str list):          specify the training losses that you want to plot and save.
    -- self.model_names (str list):         define networks used in our training.
    -- self.visual_names (str list):        specify the images that you want to display and save.
    -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.

Now you can use the model class by specifying flag '--model dummy'.
See our template model class 'template_model.py' for more details.
"""
# ↑ 包级文档字符串:说明该包与“目标函数/优化/网络结构”相关;
#   保持原样不改动,因为它本身就是可读的说明文档。

import importlib  # 动态导入模块的标准库(按模块名字符串加载 Python 模块)
from models.base_model import BaseModel  # 引入所有模型应继承的基类,用于类型约束与通用接口


def find_model_using_name(model_name: str):
    """Import the module "models/[model_name]_model.py".

    In the file, the class called DatasetNameModel() will
    be instantiated. It has to be a subclass of BaseModel,
    and it is case-insensitive.
    """
    # 目标:根据传入的 model_name,动态加载对应的 “models/{model_name}_model.py”,
    # 人话:根据需要的模型名字找到对应的py文件
    # 并在其中找到一个类,其类名与 “{model_name}model”(忽略下划线与大小写)匹配,且必须继承 BaseModel。

    model_filename = "models." + model_name + "_model"  
    # 组装模块路径,如 model_name='cycle_gan' → 'models.cycle_gan_model'
    modellib = importlib.import_module(model_filename)  # 动态导入该模块,获得模块对象
    model = None  # 定义变量用于保存最终找到的类;
    target_model_name = model_name.replace("_", "") + "model"  
    # 规范化目标类名:去掉下划线后再拼接 'model'(如 'cycle_gan' → 'cycleganmodel')

    # 遍历模块的符号表(字典):键 name 是符号名(类/函数/变量名),值 cls 是对象本身
    # (这里和之前创建dataset其实是同一个逻辑)
    for name, cls in modellib.__dict__.items():
        # 条件1:名称大小写不敏感地匹配规范化后的目标类名;
        # 条件2:并且该对象是 BaseModel 的子类(确保模型具备统一接口)
        if name.lower() == target_model_name.lower() and issubclass(cls, BaseModel):
            model = cls  # 命中则记录下来(不立即返回是因为要先遍历完或允许后续覆盖)

    if model is None:
        # 若没有在模块中找到满足条件的类,打印明确的提示并退出进程(也可改成抛异常,看项目风格)
        print(
            "In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase."
            % (model_filename, target_model_name)
        )
        exit(0)  # 直接结束程序,防止后续使用空模型导致更隐蔽的问题

    return model  # 返回类对象本身(注意:不是实例)


def get_option_setter(model_name: str):
    """Return the static method <modify_commandline_options> of the model class."""
    # 用于拿到某个模型类自带的“静态方法” modify_commandline_options(为 argparse 添加/修改模型特定的命令行参数)
    model_class = find_model_using_name(model_name)  # 先解析得到模型类
    return model_class.modify_commandline_options  # 返回该类的静态方法本身,供外部调用


def create_model(opt):
    """Create a model given the option."""
    # 根据运行时配置 opt 创建对应模型的“实例”
    model = find_model_using_name(opt.model)  # 从 opt.model(字符串)解析出具体的模型类
    instance = model(opt)  # 调用其构造函数完成实例化,并把全局/训练配置传入
    print(f"model [{type(instance).__name__}] was created")  # 打印已创建模型的类名,便于日志追踪
    return instance  # 返回模型实例给训练/推理流程使用

posted @ 2025-09-24 16:24  SaTsuki26681534  阅读(7)  评论(0)    收藏  举报