
Stable Diffusion WebUI Forge Code Walkthrough

ControlNet

In the original WebUI, ControlNet exists as an extension; in Forge, the developers integrated it directly into the codebase.
The script's processing entry point (the process hook, run first after clicking Generate):
extensions-builtin/sd_forge_controlnet/scripts/controlnet.py

    @torch.no_grad()
    def process(self, p, *args, **kwargs):
        self.current_params = {}
        enabled_units = self.get_enabled_units(args)
        Infotext.write_infotext(enabled_units, p)
        for i, unit in enumerate(enabled_units):
            self.bound_check_params(unit)
            params = ControlNetCachedParameters()
            self.process_unit_after_click_generate(p, unit, params, *args, **kwargs)
            self.current_params[i] = params
        return
Within it, the preprocessing in process_unit_after_click_generate decides whether a control model is needed and loads it:
scripts/controlnet.py, process_unit_after_click_generate
 if preprocessor.do_not_need_model:
    model_filename = 'Not Needed'
    params.model = ControlModelPatcher()
else:
    assert unit.model != 'None', 'You have not selected any control model!'
    model_filename = global_state.get_controlnet_filename(unit.model)
    params.model = cached_controlnet_loader(model_filename)
    assert params.model is not None, logger.error(f"Recognizing Control Model failed: {model_filename}")
cached_controlnet_loader() calls try_load_supported_control_model():
scripts/controlnet.py
 @functools.lru_cache(maxsize=shared.opts.data.get("control_net_model_cache_size", 5))
def cached_controlnet_loader(filename):
    return try_load_supported_control_model(filename)
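Because of functools.lru_cache, a checkpoint is only read from disk the first time its filename is requested; later units that reuse the same model get the cached object back. A minimal sketch of this caching behavior (the filename and maxsize here are hypothetical):

import functools

@functools.lru_cache(maxsize=2)
def load_model(filename):
    print(f"loading {filename} from disk")  # only printed on a cache miss
    return object()

a = load_model("canny.safetensors")  # miss: prints and loads
b = load_model("canny.safetensors")  # hit: returns the cached object silently
assert a is b

One side effect of this pattern: the maxsize argument is evaluated once, when the decorator runs at import time, so changing the control_net_model_cache_size option later has no effect until the program restarts.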
try_load_supported_control_model is defined in modules/shared.py:
modules/shared.py
 def try_load_supported_control_model(ckpt_path):
    global supported_control_models
    state_dict = ldm_patched.modules.utils.load_torch_file(ckpt_path, safe_load=True)
    for supported_type in supported_control_models:
        state_dict_copy = {k: v for k, v in state_dict.items()}
        model = supported_type.try_build_from_state_dict(state_dict_copy, ckpt_path)
        if model is not None:
            return model
    return None
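Each entry in supported_control_models is a class whose try_build_from_state_dict inspects the checkpoint's keys and returns None when it does not recognize them, so the first type that matches wins; the shallow copy of state_dict lets each candidate mutate its own view of the keys without corrupting the dict for the next candidate. A minimal sketch of that dispatch pattern (the classes and key names below are illustrative, not Forge's real ones):

class FakeControlNetType:
    @staticmethod
    def try_build_from_state_dict(sd, ckpt_path):
        # recognize the checkpoint by a characteristic key, else decline
        return "controlnet model" if "input_hint_block.0.weight" in sd else None

class FakeT2IAdapterType:
    @staticmethod
    def try_build_from_state_dict(sd, ckpt_path):
        return "t2i adapter model" if "body.0.block1.weight" in sd else None

supported = [FakeControlNetType, FakeT2IAdapterType]
sd = {"input_hint_block.0.weight": 0}
model = None
for cls in supported:
    model = cls.try_build_from_state_dict(dict(sd), "x.safetensors")
    if model is not None:
        break
print(model)  # controlnet model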
For each registered supported control model type, try_build_from_state_dict() attempts to load the weights into that architecture; the raw ControlNet structure is built first:
modules_forge/supported_controlnet.py
 def try_build_from_state_dict(controlnet_data, ckpt_path):
...
    control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
...
            class WeightsLoader(torch.nn.Module):
                pass

            w = WeightsLoader()
            w.control_model = control_model
            missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
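The throwaway WeightsLoader class exists only to give the checkpoint keys their "control_model." prefix: assigning the network to an attribute named control_model registers it as a submodule under exactly that prefix, so load_state_dict resolves the keys without any manual renaming. A self-contained sketch (hypothetical two-unit network):

import torch

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Linear(2, 2)

class WeightsLoader(torch.nn.Module):
    pass

loader = WeightsLoader()
loader.control_model = Net()  # the attribute name becomes the state-dict key prefix
sd = {"control_model.w.weight": torch.zeros(2, 2),
      "control_model.w.bias": torch.zeros(2)}
missing, unexpected = loader.load_state_dict(sd, strict=False)
print(missing, unexpected)  # [] []: every key resolved through the prefix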
The raw ControlNet class definition:
ldm_patched/controlnet/cldm.py
 class ControlNet(nn.Module):
    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        hint_channels,
        num_res_blocks,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        dtype=torch.float32,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
        adm_in_channels=None,
        transformer_depth_middle=None,
        transformer_depth_output=None,
        device=None,
        operations=ldm_patched.modules.ops.disable_weight_init,
        **kwargs,
    ):
 
ldm_patched/controlnet/cldm.py
 def forward(self, x, hint, timesteps, context, y=None, **kwargs):
        # encode the timestep t into an embedding vector
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        # linear -> silu -> linear; emb.shape = (B, time_embed_dim)
        emb = self.time_embed(t_emb)
        # only the hint goes through conv->silu->conv->silu->...->conv->silu->zero_conv
        guided_hint = self.input_hint_block(hint, emb, context)

        outs = []

        hs = []
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            if guided_hint is not None:
                h = module(h, emb, context)  # conv_nd(h)
                h += guided_hint  # x + conv_nd(hint)
                guided_hint = None
            else:
                h = module(h, emb, context)
            # run the hint-augmented h through a zero conv and collect it
            outs.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)
        # made by make_zero_conv
        outs.append(self.middle_block_out(h, emb, context))

        return outs
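Each collected feature map passes through a zero conv: a 1x1 convolution whose weights and bias start at zero, so an untrained ControlNet contributes exactly nothing to the UNet and training can grow the control signal gradually. A minimal sketch (hypothetical channel count):

import torch
import torch.nn as nn

def make_zero_conv(channels):
    conv = nn.Conv2d(channels, channels, kernel_size=1)
    nn.init.zeros_(conv.weight)  # zero-initialized, as in ControlNet's zero_convs
    nn.init.zeros_(conv.bias)
    return conv

zc = make_zero_conv(320)
h = torch.randn(2, 320, 64, 64)
print(zc(h).abs().max())  # tensor(0.): the control branch starts out silent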
Per the script hook dispatch in modules/scripts.py, process_before_every_sampling is invoked automatically before every sampling pass.
scripts/controlnet.py
 def process_before_every_sampling(self, p, *args, **kwargs):
    for i, unit in enumerate(self.get_enabled_units(args)):
        self.process_unit_before_every_sampling(p, unit, self.current_params[i], *args, **kwargs)
    return
Its counterpart after each sampling pass is process_unit_after_every_sampling:
scripts/controlnet.py
 def process_unit_after_every_sampling(self,
                                          p: StableDiffusionProcessing,
                                          unit: ControlNetUnit,
                                          params: ControlNetCachedParameters,
                                          *args, **kwargs):

        params.preprocessor.process_after_every_sampling(p, params, *args, **kwargs)
        params.model.process_after_every_sampling(p, params, *args, **kwargs)
        return
The core per-unit work happens in the control model patcher's process_before_every_sampling:
modules_forge/supported_controlnet.py
  def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
        unet = process.sd_model.forge_objects.unet

        unet = apply_controlnet_advanced(
            unet=unet,
            controlnet=self.model_patcher,
            image_bchw=cond,
            strength=self.strength,
            start_percent=self.start_percent,
            end_percent=self.end_percent,
            positive_advanced_weighting=self.positive_advanced_weighting,
            negative_advanced_weighting=self.negative_advanced_weighting,
            advanced_frame_weighting=self.advanced_frame_weighting,
            advanced_sigma_weighting=self.advanced_sigma_weighting,
            advanced_mask_weighting=self.advanced_mask_weighting
        )

        process.sd_model.forge_objects.unet = unet
        return
apply_controlnet_advanced attaches these parameters to the cnet model patcher:
modules_forge/controlnet.py
 def apply_controlnet_advanced(
        unet,
        controlnet,
        image_bchw,
        strength,
        start_percent,
        end_percent,
        positive_advanced_weighting=None,
        negative_advanced_weighting=None,
        advanced_frame_weighting=None,
        advanced_sigma_weighting=None,
        advanced_mask_weighting=None
):
    # set_cond_hint stores the conditioning hint: the image tensor, the strength, and the (start, end) percentage range
    cnet = controlnet.copy().set_cond_hint(image_bchw, strength, (start_percent, end_percent))
    # set the advanced weighting parameters: positive/negative weighting, per-frame weighting, and sigma weighting
    cnet.positive_advanced_weighting = positive_advanced_weighting
    cnet.negative_advanced_weighting = negative_advanced_weighting
    cnet.advanced_frame_weighting = advanced_frame_weighting
    cnet.advanced_sigma_weighting = advanced_sigma_weighting
    
    # if an advanced mask weighting is given, validate its type and shape before assigning it to cnet
    if advanced_mask_weighting is not None:
        assert isinstance(advanced_mask_weighting, torch.Tensor)
        B, C, H, W = advanced_mask_weighting.shape
        assert B > 0 and C == 1 and H > 0 and W > 0

    cnet.advanced_mask_weighting = advanced_mask_weighting
    # clone the UNet model; add_patched_controlnet attaches the configured ControlNet copy to the clone
    m = unet.clone()
    m.add_patched_controlnet(cnet)
    return m
add_patched_controlnet prepends cnet to a linked list stored on the UNet patcher:
modules_forge/unet_patcher.py
 def add_patched_controlnet(self, cnet):
        cnet.set_previous_controlnet(self.controlnet_linked_list)
        self.controlnet_linked_list = cnet
        return
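Each call prepends the new cnet, so several enabled units form a singly linked list that is traversed newest-first via previous_controlnet. A minimal sketch with a hypothetical stand-in class:

class FakeCnet:
    def __init__(self, name):
        self.name = name
        self.previous_controlnet = None

    def set_previous_controlnet(self, prev):
        self.previous_controlnet = prev
        return self

controlnet_linked_list = None
for name in ["depth", "canny", "openpose"]:
    cnet = FakeCnet(name)
    cnet.set_previous_controlnet(controlnet_linked_list)
    controlnet_linked_list = cnet

node = controlnet_linked_list
while node is not None:
    print(node.name)  # openpose, canny, depth: newest unit first
    node = node.previous_controlnet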
forge_sample in forge_sampler.py reads this linked list, controlnet_linked_list:
modules_forge/forge_sampler.py
 def forge_sample(self, denoiser_params, cond_scale, cond_composition):
    model = self.inner_model.inner_model.forge_objects.unet.model
    # ----here----
    control = self.inner_model.inner_model.forge_objects.unet.controlnet_linked_list

    extra_concat_condition = self.inner_model.inner_model.forge_objects.unet.extra_concat_condition
    x = denoiser_params.x
    timestep = denoiser_params.sigma
    uncond = cond_from_a1111_to_patched_ldm(denoiser_params.text_uncond)
    cond = cond_from_a1111_to_patched_ldm_weighted(denoiser_params.text_cond, cond_composition)
    model_options = self.inner_model.inner_model.forge_objects.unet.model_options
    seed = self.p.seeds[0]

    if extra_concat_condition is not None:
        image_cond_in = extra_concat_condition
    else:
        image_cond_in = denoiser_params.image_cond

    if isinstance(image_cond_in, torch.Tensor):
        if image_cond_in.shape[0] == x.shape[0] \
                and image_cond_in.shape[2] == x.shape[2] \
                and image_cond_in.shape[3] == x.shape[3]:
            for i in range(len(uncond)):
                uncond[i]['model_conds']['c_concat'] = CONDRegular(image_cond_in)
            for i in range(len(cond)):
                cond[i]['model_conds']['c_concat'] = CONDRegular(image_cond_in)
    # ----here----
    if control is not None:
        for h in cond + uncond:
            h['control'] = control

    for modifier in model_options.get('conditioning_modifiers', []):
        model, x, timestep, uncond, cond, cond_scale, model_options, seed = modifier(model, x, timestep, uncond, cond, cond_scale, model_options, seed)
    
    # sampling_function performs the actual denoising step (classifier-free guidance)
    denoised = sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options, seed)
    return denoised
Inside sampling_function, calc_cond_uncond_batch computes the ControlNet output:
ldm_patched/modules/samplers.py
 def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
    ...
    cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
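With no patches installed, the two predictions are then combined by standard classifier-free guidance; a sketch of the default rule (Forge can replace it through model_options hooks such as sampler_cfg_function):

def cfg_combine(cond_pred, uncond_pred, cond_scale):
    # classifier-free guidance: extrapolate away from the unconditional prediction
    return uncond_pred + (cond_pred - uncond_pred) * cond_scale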
There, control.get_control produces the control output:
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
    ...
        if control is not None:
            p = control
            while p is not None:
                p.transformer_options = transformer_options
                p = p.previous_controlnet
            control_cond = c.copy()  # get_control may change items in this dict, so we need to copy it
            c['control'] = control.get_control(input_x, timestep_, control_cond, len(cond_or_uncond))
            c['control_model'] = control
In get_control, the ControlNet is run on the hint, and control_merge combines its outputs at the end:
 def get_control(self, x_noisy, t, cond, batched_number):
        to = self.transformer_options

        for conditioning_modifier in to.get('controlnet_conditioning_modifiers', []):
            x_noisy, t, cond, batched_number = conditioning_modifier(self, x_noisy, t, cond, batched_number)

        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        if self.timestep_range is not None:
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                if control_prev is not None:
                    return control_prev
                else:
                    return None

        dtype = self.control_model.dtype
        if self.manual_cast_dtype is not None:
            dtype = self.manual_cast_dtype

        output_dtype = x_noisy.dtype
        if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
            if self.cond_hint is not None:
                del self.cond_hint
            self.cond_hint = None
            self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype)
        if x_noisy.shape[0] != self.cond_hint.shape[0]:
            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

        context = cond['c_crossattn']
        y = cond.get('y', None)
        if y is not None:
            y = y.to(dtype)
        timestep = self.model_sampling_current.timestep(t)
        x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

        controlnet_model_function_wrapper = to.get('controlnet_model_function_wrapper', None)

        if controlnet_model_function_wrapper is not None:
            wrapper_args = dict(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(),
                                context=context.to(dtype), y=y)
            wrapper_args['model'] = self
            wrapper_args['inner_model'] = self.control_model
            control = controlnet_model_function_wrapper(**wrapper_args)
        else:
            control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint.to(self.device), timesteps=timestep.float(), context=context.to(dtype), y=y)
        return self.control_merge(None, control, control_prev, output_dtype)
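The factor of 8 in the cond_hint check comes from the VAE's downsampling: a latent with spatial size (H, W) corresponds to an (H*8, W*8) image, so the pixel-space hint is resized to match whenever the latent resolution changes (e.g. under hires fix). A quick sanity check with hypothetical shapes:

import torch

latent = torch.randn(1, 4, 64, 96)  # latent of a 512x768 image
hint_h, hint_w = latent.shape[2] * 8, latent.shape[3] * 8
print(hint_h, hint_w)  # 512 768: the resolution the hint is upscaled to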
control_merge
ldm_patched/modules/controlnet.py
     def control_merge(self, control_input, control_output, control_prev, output_dtype):
        out = {'input':[], 'middle':[], 'output': []}

        if control_input is not None:
            for i in range(len(control_input)):
                key = 'input'
                x = control_input[i]
                if x is not None:
                    x *= self.strength
                    if x.dtype != output_dtype:
                        x = x.to(output_dtype)
                out[key].insert(0, x)

        if control_output is not None:
            for i in range(len(control_output)):
                if i == (len(control_output) - 1):
                    key = 'middle'
                    index = 0
                else:
                    key = 'output'
                    index = i
                x = control_output[i]
                if x is not None:
                    if self.global_average_pooling:
                        x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])

                    x *= self.strength
                    if x.dtype != output_dtype:
                        x = x.to(output_dtype)

                out[key].append(x)

        out = compute_controlnet_weighting(out, self)

        if control_prev is not None:
            for x in ['input', 'middle', 'output']:
                o = out[x]
                for i in range(len(control_prev[x])):
                    prev_val = control_prev[x][i]
                    if i >= len(o):
                        o.append(prev_val)
                    elif prev_val is not None:
                        if o[i] is None:
                            o[i] = prev_val
                        else:
                            if o[i].shape[0] < prev_val.shape[0]:
                                o[i] = prev_val + o[i]
                            else:
                                o[i] += prev_val
        return out
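The last loop is where chained ControlNets actually combine: residuals at the same position are summed, and None slots fall back to the previous ControlNet's value. A minimal sketch of that merge rule (hypothetical tensors, ignoring the batch-broadcast branch):

import torch

out  = {'input': [], 'middle': [None], 'output': [torch.ones(1, 4, 8, 8)]}
prev = {'input': [], 'middle': [torch.ones(1, 4, 8, 8)],
        'output': [torch.full((1, 4, 8, 8), 2.0)]}

for k in ['input', 'middle', 'output']:
    o = out[k]
    for i, prev_val in enumerate(prev[k]):
        if i >= len(o):
            o.append(prev_val)
        elif prev_val is not None:
            o[i] = prev_val if o[i] is None else o[i] + prev_val

print(out['output'][0].unique())  # tensor([3.]): residuals from both nets add up
print(out['middle'][0].unique())  # tensor([1.]): None slot filled from the previous net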
From here on, execution runs through the UNet. Back in calc_cond_uncond_batch, apply_model is called.
ldm_patched/modules/samplers.py
         if 'model_function_wrapper' in model_options:
            output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
        else:
            output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
Inside apply_model, model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() invokes UNetModel's forward; with that, the whole chain of model calls is complete.
ldm_patched/modules/model_base.py
     def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
        sigma = t
        xc = self.model_sampling.calculate_input(sigma, x)
        if c_concat is not None:
            xc = torch.cat([xc] + [c_concat], dim=1)

        context = c_crossattn
        dtype = self.get_dtype()

        if self.manual_cast_dtype is not None:
            dtype = self.manual_cast_dtype

        xc = xc.to(dtype)
        t = self.model_sampling.timestep(t).float()
        context = context.to(dtype)
        extra_conds = {}
        for o in kwargs:
            extra = kwargs[o]
            if hasattr(extra, "dtype"):
                if extra.dtype != torch.int and extra.dtype != torch.long:
                    extra = extra.to(dtype)
            extra_conds[o] = extra

        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
        return self.model_sampling.calculate_denoised(sigma, model_output, x)
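For the default EPS parameterization, calculate_input and calculate_denoised implement the usual k-diffusion style scalings. A sketch of the math (assuming the EPS model type; simplified from ldm_patched's ModelSampling classes):

def calculate_input_eps(sigma, x):
    # scale the noisy latent before it enters the UNet
    return x / (sigma ** 2 + 1.0) ** 0.5

def calculate_denoised_eps(sigma, model_output, x):
    # the UNet predicts noise (eps); remove it, scaled by sigma
    return x - model_output * sigma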
The diffusion_model instance is initialized in:
ldm_patched/modules/model_base.py
 class BaseModel(torch.nn.Module):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
    ...
    self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)

UNetModel

From control_merge we can see that the control argument is a dict of precomputed residuals with three keys, 'input', 'middle', and 'output', each holding a list of tensors. apply_control applies these ControlNet outputs at the corresponding stages of the UNet.
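apply_control itself is not quoted in this post; in ldm_patched it is essentially a pop-and-add, roughly like the following sketch (simplified):

def apply_control(h, control, name):
    # pop the next residual for this stage ('input', 'middle', 'output') and add it
    if control is not None and name in control and len(control[name]) > 0:
        ctrl = control[name].pop()
        if ctrl is not None:
            h += ctrl
    return h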
 
ldm_patched/ldm/modules/diffusionmodules/openaimodel.py
 # UNetModel    
    def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        transformer_options["original_shape"] = list(x.shape)
        transformer_options["transformer_index"] = 0
        transformer_patches = transformer_options.get("patches", {})
        block_modifiers = transformer_options.get("block_modifiers", [])

        num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
        image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
        time_context = kwargs.get("time_context", None)

        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x
        for id, module in enumerate(self.input_blocks):
            transformer_options["block"] = ("input", id)

            for block_modifier in block_modifiers:
                h = block_modifier(h, 'before', transformer_options)

            h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
            
            # inject the ControlNet residual after each input block
            h = apply_control(h, control, 'input')

            for block_modifier in block_modifiers:
                h = block_modifier(h, 'after', transformer_options)

            if "input_block_patch" in transformer_patches:
                patch = transformer_patches["input_block_patch"]
                for p in patch:
                    h = p(h, transformer_options)

            hs.append(h)
            if "input_block_patch_after_skip" in transformer_patches:
                patch = transformer_patches["input_block_patch_after_skip"]
                for p in patch:
                    h = p(h, transformer_options)

        transformer_options["block"] = ("middle", 0)

        for block_modifier in block_modifiers:
            h = block_modifier(h, 'before', transformer_options)

        h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
        
        # inject the ControlNet residual after the middle block
        h = apply_control(h, control, 'middle')

        for block_modifier in block_modifiers:
            h = block_modifier(h, 'after', transformer_options)

        for id, module in enumerate(self.output_blocks):
            transformer_options["block"] = ("output", id)
            hsp = hs.pop()
            # inject the ControlNet residual into the skip connection before each output block
            hsp = apply_control(hsp, control, 'output')

            if "output_block_patch" in transformer_patches:
                patch = transformer_patches["output_block_patch"]
                for p in patch:
                    h, hsp = p(h, hsp, transformer_options)

            h = th.cat([h, hsp], dim=1)
            del hsp
            if len(hs) > 0:
                output_shape = hs[-1].shape
            else:
                output_shape = None

            for block_modifier in block_modifiers:
                h = block_modifier(h, 'before', transformer_options)

            h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)

            for block_modifier in block_modifiers:
                h = block_modifier(h, 'after', transformer_options)

        transformer_options["block"] = ("last", 0)

        for block_modifier in block_modifiers:
            h = block_modifier(h, 'before', transformer_options)

        if self.predict_codebook_ids:
            h = self.id_predictor(h)
        else:
            h = self.out(h)

        for block_modifier in block_modifiers:
            h = block_modifier(h, 'after', transformer_options)

        return h.type(x.dtype)
 

IPAdapter: the correspondence between preprocessors and models in Forge ControlNet (the original table is not preserved in this copy).
