Stable Diffusion WebUI Forge Code Walkthrough
ControlNet
In the original WebUI, ControlNet exists as an extension. In Forge, the developers integrated ControlNet directly into the codebase.
Script entry point (process runs first)
extensions-builtin/sd_forge_controlnet/scripts/controlnet.py
@torch.no_grad()
def process(self, p, *args, **kwargs):
    self.current_params = {}
    enabled_units = self.get_enabled_units(args)
    Infotext.write_infotext(enabled_units, p)
    for i, unit in enumerate(enabled_units):
        self.bound_check_params(unit)
        params = ControlNetCachedParameters()
        self.process_unit_after_click_generate(p, unit, params, *args, **kwargs)
        self.current_params[i] = params
    return
Within it, the preprocessing in process_unit_after_click_generate:
scripts/controlnet.py/process_unit_after_click_generate
if preprocessor.do_not_need_model:
    model_filename = 'Not Needed'
    params.model = ControlModelPatcher()
else:
    assert unit.model != 'None', 'You have not selected any control model!'
    model_filename = global_state.get_controlnet_filename(unit.model)
    params.model = cached_controlnet_loader(model_filename)

assert params.model is not None, logger.error(f"Recognizing Control Model failed: {model_filename}")
cached_controlnet_loader() calls try_load_supported_control_model():
scripts/controlnet.py
@functools.lru_cache(maxsize=shared.opts.data.get("control_net_model_cache_size", 5))
def cached_controlnet_loader(filename):
    return try_load_supported_control_model(filename)
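Because the lru_cache key is just the filename string, repeated generations with the same control model skip the disk load entirely. A minimal, self-contained sketch of the same memoization pattern (load_expensive_model is a hypothetical stand-in):

import functools

@functools.lru_cache(maxsize=5)
def load_expensive_model(filename: str):
    # hypothetical stand-in for try_load_supported_control_model: the first
    # call per filename pays the full load cost; later calls with the same
    # string return the cached object
    print(f"loading {filename} from disk ...")
    return object()

a = load_expensive_model("control_v11p_sd15_canny.safetensors")
b = load_expensive_model("control_v11p_sd15_canny.safetensors")  # cache hit, no print
assert a is b

Note that maxsize is evaluated once when the decorator runs at import time, so changing control_net_model_cache_size later does not resize an already-created cache.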
try_load_supported_control_model is defined in modules/shared.py:
modules/shared.py
def try_load_supported_control_model(ckpt_path):
    global supported_control_models
    state_dict = ldm_patched.modules.utils.load_torch_file(ckpt_path, safe_load=True)
    for supported_type in supported_control_models:
        state_dict_copy = {k: v for k, v in state_dict.items()}
        model = supported_type.try_build_from_state_dict(state_dict_copy, ckpt_path)
        if model is not None:
            return model
    return None
For each supported control model type, try_build_from_state_dict() attempts to build a model from the loaded checkpoint weights; here the raw ControlNet structure is instantiated first:
modules_forge/supported_controlnet.py
def try_build_from_state_dict(controlnet_data, ckpt_path):
    ...
    control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
    ...

    class WeightsLoader(torch.nn.Module):
        pass

    w = WeightsLoader()
    w.control_model = control_model
    missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
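The empty WeightsLoader is a prefix trick: attaching the model under the attribute name control_model makes every parameter path in w start with control_model., which matches the key layout of ControlNet checkpoints, so load_state_dict can consume them directly. A minimal demonstration with a toy module (illustrative names):

import torch

class Inner(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

class WeightsLoader(torch.nn.Module):
    pass

w = WeightsLoader()
w.control_model = Inner()  # registered as submodule "control_model"

# a state dict whose keys carry the "control_model." prefix, like the
# checkpoints loaded here
prefixed = {"control_model." + k: v for k, v in Inner().state_dict().items()}
missing, unexpected = w.load_state_dict(prefixed, strict=False)
print(missing, unexpected)  # both empty: every prefixed key found its parameter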
The raw ControlNet class definition:
ldm_patched/controlnet/cldm.py
class ControlNet(nn.Module):
    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        hint_channels,
        num_res_blocks,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        dtype=torch.float32,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
        adm_in_channels=None,
        transformer_depth_middle=None,
        transformer_depth_output=None,
        device=None,
        operations=ldm_patched.modules.ops.disable_weight_init,
        **kwargs,
    ):
ldm_patched/controlnet/cldm.py
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
    # encode the timestep t into a vector
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
    # Linear -> SiLU -> Linear; emb.shape = (B, time_embed_dim)
    emb = self.time_embed(t_emb)
    # only the hint goes through conv -> silu -> conv -> silu -> ... -> zero_conv
    guided_hint = self.input_hint_block(hint, emb, context)

    outs = []
    hs = []

    if self.num_classes is not None:
        assert y.shape[0] == x.shape[0]
        emb = emb + self.label_emb(y)

    h = x
    for module, zero_conv in zip(self.input_blocks, self.zero_convs):
        if guided_hint is not None:
            h = module(h, emb, context)
            h += guided_hint  # inject the processed hint into the first input block only
            guided_hint = None
        else:
            h = module(h, emb, context)
        # each input block's feature map passes through its own zero conv
        outs.append(zero_conv(h, emb, context))

    h = self.middle_block(h, emb, context)
    # middle_block_out is also a zero conv (make_zero_conv)
    outs.append(self.middle_block_out(h, emb, context))
    return outs
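The zero_convs (and middle_block_out) are 1×1 convolutions initialized to all zeros, so an untrained ControlNet contributes an exactly-zero residual to the UNet and control is learned gradually. A minimal sketch of the idea (not Forge's exact make_zero_conv):

import torch

def make_zero_conv_sketch(channels: int) -> torch.nn.Conv2d:
    # 1x1 conv with zero weights and bias: its output starts as an all-zero
    # residual, so adding it to UNet features is initially a no-op
    conv = torch.nn.Conv2d(channels, channels, kernel_size=1)
    torch.nn.init.zeros_(conv.weight)
    torch.nn.init.zeros_(conv.bias)
    return conv

zc = make_zero_conv_sketch(320)
h = torch.randn(1, 320, 64, 64)
assert torch.all(zc(h) == 0)  # zero residual before any training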
According to Scripts.scripts, this hook is invoked automatically before every sampling pass:
scripts/controlnet.py
def process_before_every_sampling(self, p, *args, **kwargs):
    for i, unit in enumerate(self.get_enabled_units(args)):
        self.process_unit_before_every_sampling(p, unit, self.current_params[i], *args, **kwargs)
    return
Its post-sampling counterpart is process_unit_after_every_sampling:
scripts/controlnet.py
def process_unit_after_every_sampling(self,
                                      p: StableDiffusionProcessing,
                                      unit: ControlNetUnit,
                                      params: ControlNetCachedParameters,
                                      *args, **kwargs):
    params.preprocessor.process_after_every_sampling(p, params, *args, **kwargs)
    params.model.process_after_every_sampling(p, params, *args, **kwargs)
    return
The core processing logic lives in the model patcher's process_before_every_sampling:
modules_forge/supported_controlnet.py
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
    unet = process.sd_model.forge_objects.unet
    unet = apply_controlnet_advanced(
        unet=unet,
        controlnet=self.model_patcher,
        image_bchw=cond,
        strength=self.strength,
        start_percent=self.start_percent,
        end_percent=self.end_percent,
        positive_advanced_weighting=self.positive_advanced_weighting,
        negative_advanced_weighting=self.negative_advanced_weighting,
        advanced_frame_weighting=self.advanced_frame_weighting,
        advanced_sigma_weighting=self.advanced_sigma_weighting,
        advanced_mask_weighting=self.advanced_mask_weighting
    )
    process.sd_model.forge_objects.unet = unet
    return
apply_controlnet_advanced attaches these parameters to the cnet model:
modules_forge/controlnet.py
def apply_controlnet_advanced(
        unet,
        controlnet,
        image_bchw,
        strength,
        start_percent,
        end_percent,
        positive_advanced_weighting=None,
        negative_advanced_weighting=None,
        advanced_frame_weighting=None,
        advanced_sigma_weighting=None,
        advanced_mask_weighting=None
):
    # set_cond_hint stores the conditioning hint: the image tensor, the
    # strength, and the (start, end) percentage range of the schedule
    cnet = controlnet.copy().set_cond_hint(image_bchw, strength, (start_percent, end_percent))

    # attach the advanced weighting parameters: positive/negative prompt
    # weighting, per-frame weighting, and per-sigma weighting
    cnet.positive_advanced_weighting = positive_advanced_weighting
    cnet.negative_advanced_weighting = negative_advanced_weighting
    cnet.advanced_frame_weighting = advanced_frame_weighting
    cnet.advanced_sigma_weighting = advanced_sigma_weighting

    # if an advanced mask weighting is provided, validate its type and
    # shape (B, 1, H, W) before assigning it to cnet
    if advanced_mask_weighting is not None:
        assert isinstance(advanced_mask_weighting, torch.Tensor)
        B, C, H, W = advanced_mask_weighting.shape
        assert B > 0 and C == 1 and H > 0 and W > 0
        cnet.advanced_mask_weighting = advanced_mask_weighting

    # clone the UNet patcher and register the configured ControlNet copy on it
    # via add_patched_controlnet
    m = unet.clone()
    m.add_patched_controlnet(cnet)
    return m
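The (start_percent, end_percent) pair controls when during the schedule the unit is active; get_control later converts it into a timestep range and skips the control model outside it. A minimal sketch of that gating semantics (illustrative helper, not Forge's exact code):

def is_active(progress: float, start_percent: float, end_percent: float) -> bool:
    # progress runs from 0.0 (first sampling step) to 1.0 (last step)
    return start_percent <= progress <= end_percent

assert is_active(0.3, 0.0, 0.5)      # unit applies during the first half of sampling
assert not is_active(0.8, 0.0, 0.5)  # and is skipped afterwards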
add_patched_controlnet then attaches cnet to the UNet patcher, chaining units into a linked list:
modules_forge/unet_patcher.py
def add_patched_controlnet(self, cnet):
    cnet.set_previous_controlnet(self.controlnet_linked_list)
    self.controlnet_linked_list = cnet
    return
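Each newly added unit becomes the head of the chain and points at the previous head, so multiple enabled ControlNet units form a singly linked list that get_control later walks via previous_controlnet. A minimal sketch of the structure (illustrative class, not Forge's actual patcher type):

class CnetNode:
    def __init__(self, name):
        self.name = name
        self.previous_controlnet = None

    def set_previous_controlnet(self, prev):
        self.previous_controlnet = prev
        return self

controlnet_linked_list = None
for name in ["canny", "depth", "openpose"]:
    controlnet_linked_list = CnetNode(name).set_previous_controlnet(controlnet_linked_list)

# traversal is newest-first: openpose -> depth -> canny
p = controlnet_linked_list
while p is not None:
    print(p.name)
    p = p.previous_controlnet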
forge_sample in modules_forge/forge_sampler.py consumes this controlnet_linked_list:
modules_forge/forge_sampler.py
def forge_sample(self, denoiser_params, cond_scale, cond_composition):
    model = self.inner_model.inner_model.forge_objects.unet.model
    # ----here----
    control = self.inner_model.inner_model.forge_objects.unet.controlnet_linked_list
    extra_concat_condition = self.inner_model.inner_model.forge_objects.unet.extra_concat_condition
    x = denoiser_params.x
    timestep = denoiser_params.sigma
    uncond = cond_from_a1111_to_patched_ldm(denoiser_params.text_uncond)
    cond = cond_from_a1111_to_patched_ldm_weighted(denoiser_params.text_cond, cond_composition)
    model_options = self.inner_model.inner_model.forge_objects.unet.model_options
    seed = self.p.seeds[0]

    if extra_concat_condition is not None:
        image_cond_in = extra_concat_condition
    else:
        image_cond_in = denoiser_params.image_cond

    if isinstance(image_cond_in, torch.Tensor):
        if image_cond_in.shape[0] == x.shape[0] \
                and image_cond_in.shape[2] == x.shape[2] \
                and image_cond_in.shape[3] == x.shape[3]:
            for i in range(len(uncond)):
                uncond[i]['model_conds']['c_concat'] = CONDRegular(image_cond_in)
            for i in range(len(cond)):
                cond[i]['model_conds']['c_concat'] = CONDRegular(image_cond_in)

    # ----here----
    if control is not None:
        for h in cond + uncond:
            h['control'] = control

    for modifier in model_options.get('conditioning_modifiers', []):
        model, x, timestep, uncond, cond, cond_scale, model_options, seed = modifier(model, x, timestep, uncond, cond, cond_scale, model_options, seed)

    # sampling_function performs the actual sampling step
    denoised = sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options, seed)
    return denoised
Inside sampling_function, calc_cond_uncond_batch computes the ControlNet outputs:
ldm_patched/modules/samplers.py
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
    ...
    cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
There, the outputs are obtained through control.get_control:
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
    ...
    if control is not None:
        p = control
        while p is not None:
            p.transformer_options = transformer_options
            p = p.previous_controlnet
        control_cond = c.copy()  # get_control may change items in this dict, so we need to copy it
        c['control'] = control.get_control(input_x, timestep_, control_cond, len(cond_or_uncond))
        c['control_model'] = control
The get_control function, which finally merges the outputs via control_merge:
def get_control(self, x_noisy, t, cond, batched_number):
    to = self.transformer_options

    for conditioning_modifier in to.get('controlnet_conditioning_modifiers', []):
        x_noisy, t, cond, batched_number = conditioning_modifier(self, x_noisy, t, cond, batched_number)

    control_prev = None
    if self.previous_controlnet is not None:
        control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

    if self.timestep_range is not None:
        if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
            if control_prev is not None:
                return control_prev
            else:
                return None

    dtype = self.control_model.dtype
    if self.manual_cast_dtype is not None:
        dtype = self.manual_cast_dtype

    output_dtype = x_noisy.dtype
    if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
        if self.cond_hint is not None:
            del self.cond_hint
        self.cond_hint = None
        self.cond_hint = ldm_patched.modules.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype)
    if x_noisy.shape[0] != self.cond_hint.shape[0]:
        self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

    context = cond['c_crossattn']
    y = cond.get('y', None)
    if y is not None:
        y = y.to(dtype)
    timestep = self.model_sampling_current.timestep(t)
    x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

    controlnet_model_function_wrapper = to.get('controlnet_model_function_wrapper', None)

    if controlnet_model_function_wrapper is not None:
        wrapper_args = dict(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(),
                            context=context.to(dtype), y=y)
        wrapper_args['model'] = self
        wrapper_args['inner_model'] = self.control_model
        control = controlnet_model_function_wrapper(**wrapper_args)
    else:
        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint.to(self.device), timesteps=timestep.float(), context=context.to(dtype), y=y)

    return self.control_merge(None, control, control_prev, output_dtype)
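The ×8 factor in the resize check reflects the VAE downsampling: latents are 1/8 of pixel resolution, so the hint is lazily upscaled to match the current latent and cached until the latent size changes. A small worked example with assumed shapes:

import torch

x_noisy = torch.randn(2, 4, 64, 64)   # a 64x64 latent corresponds to a 512x512 image
target_w = x_noisy.shape[3] * 8       # 512
target_h = x_noisy.shape[2] * 8       # 512
print(target_w, target_h)             # the hint is resized to 512x512 once and cached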
control_merge
ldm_patched/modules/controlnet.py
def control_merge(self, control_input, control_output, control_prev, output_dtype):
    out = {'input': [], 'middle': [], 'output': []}

    if control_input is not None:
        for i in range(len(control_input)):
            key = 'input'
            x = control_input[i]
            if x is not None:
                x *= self.strength
                if x.dtype != output_dtype:
                    x = x.to(output_dtype)
            out[key].insert(0, x)

    if control_output is not None:
        for i in range(len(control_output)):
            if i == (len(control_output) - 1):
                key = 'middle'
                index = 0
            else:
                key = 'output'
                index = i
            x = control_output[i]
            if x is not None:
                if self.global_average_pooling:
                    x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
                x *= self.strength
                if x.dtype != output_dtype:
                    x = x.to(output_dtype)
            out[key].append(x)

    out = compute_controlnet_weighting(out, self)

    if control_prev is not None:
        for x in ['input', 'middle', 'output']:
            o = out[x]
            for i in range(len(control_prev[x])):
                prev_val = control_prev[x][i]
                if i >= len(o):
                    o.append(prev_val)
                elif prev_val is not None:
                    if o[i] is None:
                        o[i] = prev_val
                    else:
                        if o[i].shape[0] < prev_val.shape[0]:
                            o[i] = prev_val + o[i]
                        else:
                            o[i] += prev_val
    return out
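For a standard SD1.5 ControlNet, the forward pass above returned 13 residuals: one per input-block zero conv (12) plus middle_block_out, so control_merge routes the last entry to 'middle' and the rest, in order, to 'output' (they are later added to the UNet's skip connections). A tiny runnable sketch of that routing, assuming SD1.5 geometry:

control_output = [f"in_{i}" for i in range(12)] + ["mid"]  # 12 input residuals + middle

out = {'input': [], 'middle': [], 'output': []}
for i, x in enumerate(control_output):
    if i == len(control_output) - 1:
        out['middle'].append(x)   # the final residual steers the middle block
    else:
        out['output'].append(x)   # consumed via skip connections in the output blocks

print(len(out['output']), out['middle'])  # 12 ['mid']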
From here on, execution enters the UNet. Back in calc_cond_uncond_batch, apply_model is invoked:
ldm_patched/modules/samplers.py
if 'model_function_wrapper' in model_options:
    output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:
    output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
Here, model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() dispatches into UNetModel's forward; at this point every model involved has been called.
ldm_patched/modules/model_base.py
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
    sigma = t
    xc = self.model_sampling.calculate_input(sigma, x)
    if c_concat is not None:
        xc = torch.cat([xc] + [c_concat], dim=1)

    context = c_crossattn
    dtype = self.get_dtype()

    if self.manual_cast_dtype is not None:
        dtype = self.manual_cast_dtype

    xc = xc.to(dtype)
    t = self.model_sampling.timestep(t).float()
    context = context.to(dtype)
    extra_conds = {}
    for o in kwargs:
        extra = kwargs[o]
        if hasattr(extra, "dtype"):
            if extra.dtype != torch.int and extra.dtype != torch.long:
                extra = extra.to(dtype)
        extra_conds[o] = extra

    model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
    return self.model_sampling.calculate_denoised(sigma, model_output, x)
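calculate_input scales the noisy latent for the network and calculate_denoised converts the prediction back to a denoised latent; for the common EPS parameterization the network predicts noise, so x0 = x - sigma * eps_hat. A minimal sketch of that final step (assumed EPS semantics, illustrative helper name):

import torch

def calculate_denoised_eps(sigma: torch.Tensor, model_output: torch.Tensor, x: torch.Tensor):
    # EPS parameterization: recover the denoised latent as x0 = x - sigma * eps_hat
    sigma = sigma.view(sigma.shape[0], *([1] * (x.ndim - 1)))
    return x - model_output * sigma

x = torch.randn(2, 4, 64, 64)
eps_hat = torch.randn_like(x)
sigma = torch.tensor([1.0, 0.5])
x0 = calculate_denoised_eps(sigma, eps_hat, x)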
The diffusion_model instance is initialized in:
ldm_patched/modules/model_base.py
class BaseModel(torch.nn.Module):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        ...
        self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
UNetModel
From control_merge we know the incoming control argument is a dict of fully weighted residuals with the keys 'input', 'middle', and 'output', each mapping to a list of tensors. apply_control adds these ControlNet outputs to the corresponding stages of the UNet.
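apply_control itself is only a few lines: it pops the next residual for the named stage and adds it to the current feature map. A sketch along the lines of the ldm_patched implementation (simplified; the real function also warns on shape mismatches):

def apply_control(h, control, name):
    # pop the next residual for this stage ('input' / 'middle' / 'output') and
    # add it to the feature map; an exhausted list or a None entry leaves h unchanged
    if control is not None and name in control and len(control[name]) > 0:
        ctrl = control[name].pop()
        if ctrl is not None:
            h += ctrl
    return h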
ldm_patched/ldm/modules/diffusionmodules/openaimodel.py
# UNetModel
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
    """
    Apply the model to an input batch.
    :param x: an [N x C x ...] Tensor of inputs.
    :param timesteps: a 1-D batch of timesteps.
    :param context: conditioning plugged in via crossattn
    :param y: an [N] Tensor of labels, if class-conditional.
    :return: an [N x C x ...] Tensor of outputs.
    """
    transformer_options["original_shape"] = list(x.shape)
    transformer_options["transformer_index"] = 0
    transformer_patches = transformer_options.get("patches", {})
    block_modifiers = transformer_options.get("block_modifiers", [])

    num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
    image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
    time_context = kwargs.get("time_context", None)

    assert (y is not None) == (
        self.num_classes is not None
    ), "must specify y if and only if the model is class-conditional"
    hs = []
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
    emb = self.time_embed(t_emb)

    if self.num_classes is not None:
        assert y.shape[0] == x.shape[0]
        emb = emb + self.label_emb(y)

    h = x
    for id, module in enumerate(self.input_blocks):
        transformer_options["block"] = ("input", id)

        for block_modifier in block_modifiers:
            h = block_modifier(h, 'before', transformer_options)

        h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
        # add the ControlNet 'input' residual after each input block
        h = apply_control(h, control, 'input')

        for block_modifier in block_modifiers:
            h = block_modifier(h, 'after', transformer_options)

        if "input_block_patch" in transformer_patches:
            patch = transformer_patches["input_block_patch"]
            for p in patch:
                h = p(h, transformer_options)

        hs.append(h)
        if "input_block_patch_after_skip" in transformer_patches:
            patch = transformer_patches["input_block_patch_after_skip"]
            for p in patch:
                h = p(h, transformer_options)

    transformer_options["block"] = ("middle", 0)

    for block_modifier in block_modifiers:
        h = block_modifier(h, 'before', transformer_options)

    h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
    # add the ControlNet 'middle' residual after the middle block
    h = apply_control(h, control, 'middle')

    for block_modifier in block_modifiers:
        h = block_modifier(h, 'after', transformer_options)

    for id, module in enumerate(self.output_blocks):
        transformer_options["block"] = ("output", id)
        hsp = hs.pop()
        # add the ControlNet 'output' residual to the skip connection
        # before it is concatenated into the output block
        hsp = apply_control(hsp, control, 'output')

        if "output_block_patch" in transformer_patches:
            patch = transformer_patches["output_block_patch"]
            for p in patch:
                h, hsp = p(h, hsp, transformer_options)

        h = th.cat([h, hsp], dim=1)
        del hsp
        if len(hs) > 0:
            output_shape = hs[-1].shape
        else:
            output_shape = None

        for block_modifier in block_modifiers:
            h = block_modifier(h, 'before', transformer_options)

        h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)

        for block_modifier in block_modifiers:
            h = block_modifier(h, 'after', transformer_options)

    transformer_options["block"] = ("last", 0)

    for block_modifier in block_modifiers:
        h = block_modifier(h, 'before', transformer_options)

    if self.predict_codebook_ids:
        h = self.id_predictor(h)
    else:
        h = self.out(h)

    for block_modifier in block_modifiers:
        h = block_modifier(h, 'after', transformer_options)

    return h.type(x.dtype)
