diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 7c371deb..edd95f78 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -8,14 +8,27 @@ from modules import shared, devices, sd_models, errors metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} re_digits = re.compile(r"\d+") -re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)") -re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)") -re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)") -re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)") +re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") +re_compiled = {} + +suffix_conversion = { + "attentions": {}, + "resnets": { + "conv1": "in_layers_2", + "conv2": "out_layers_3", + "time_emb_proj": "emb_layers_1", + "conv_shortcut": "skip_connection", + } +} -def convert_diffusers_name_to_compvis(key): - def match(match_list, regex): +def convert_diffusers_name_to_compvis(key, is_sd2): + def match(match_list, regex_text): + regex = re_compiled.get(regex_text) + if regex is None: + regex = re.compile(regex_text) + re_compiled[regex_text] = regex + r = re.match(regex, key) if not r: return False @@ -26,16 +39,33 @@ def convert_diffusers_name_to_compvis(key): m = [] - if match(m, re_unet_down_blocks): - return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}" + if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" - if match(m, re_unet_mid_blocks): - return f"diffusion_model_middle_block_1_{m[1]}" + if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2]) + return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}" - if match(m, re_unet_up_blocks): - return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}" + if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"): + return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op" + + if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"): + return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv" + + if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"): + if is_sd2: + if 'mlp_fc1' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" + elif 'mlp_fc2' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" + else: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" - if match(m, re_text_block): return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" return key @@ -101,15 +131,22 @@ def load_lora(name, filename): sd = sd_models.read_state_dict(filename) - keys_failed_to_match = [] + keys_failed_to_match = {} + is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping for key_diffusers, weight in sd.items(): - fullkey = convert_diffusers_name_to_compvis(key_diffusers) - key, lora_key = fullkey.split(".", 1) + key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1) + key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2) sd_module = shared.sd_model.lora_layer_mapping.get(key, None) + if sd_module is None: - keys_failed_to_match.append(key_diffusers) + m = re_x_proj.match(key) + if m: + sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None) + + if sd_module is None: + keys_failed_to_match[key_diffusers] = key continue lora_module = lora.modules.get(key, None) @@ -123,15 +160,21 @@ def load_lora(name, filename): if type(sd_module) == torch.nn.Linear: module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(sd_module) == torch.nn.MultiheadAttention: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) elif type(sd_module) == torch.nn.Conv2d: module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) else: + print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}') + continue assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}' with torch.no_grad(): module.weight.copy_(weight) - module.to(device=devices.device, dtype=devices.dtype) + module.to(device=devices.cpu, dtype=devices.dtype) if lora_key == "lora_up.weight": lora_module.up = module @@ -177,29 +220,120 @@ def load_loras(names, multipliers=None): loaded_loras.append(lora) -def lora_forward(module, input, res): - input = devices.cond_cast_unet(input) - if len(loaded_loras) == 0: - return res +def lora_calc_updown(lora, module, target): + with torch.no_grad(): + up = module.up.weight.to(target.device, dtype=target.dtype) + down = module.down.weight.to(target.device, dtype=target.dtype) - lora_layer_name = getattr(module, 'lora_layer_name', None) - for lora in loaded_loras: - module = lora.modules.get(lora_layer_name, None) - if module is not None: - if shared.opts.lora_apply_to_outputs and res.shape == input.shape: - res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + else: + updown = up @ down + + updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + + return updown + + +def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.MultiheadAttention): + """ + Applies the currently selected set of Loras to the weights of torch layer self. + If weights already have this particular set of loras applied, does nothing. + If not, restores orginal weights from backup and alters weights according to loras. + """ + + lora_layer_name = getattr(self, 'lora_layer_name', None) + if lora_layer_name is None: + return + + current_names = getattr(self, "lora_current_names", ()) + wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras) + + weights_backup = getattr(self, "lora_weights_backup", None) + if weights_backup is None: + if isinstance(self, torch.nn.MultiheadAttention): + weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True)) + else: + weights_backup = self.weight.to(devices.cpu, copy=True) + + self.lora_weights_backup = weights_backup + + if current_names != wanted_names: + if weights_backup is not None: + if isinstance(self, torch.nn.MultiheadAttention): + self.in_proj_weight.copy_(weights_backup[0]) + self.out_proj.weight.copy_(weights_backup[1]) else: - res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + self.weight.copy_(weights_backup) - return res + for lora in loaded_loras: + module = lora.modules.get(lora_layer_name, None) + if module is not None and hasattr(self, 'weight'): + self.weight += lora_calc_updown(lora, module, self.weight) + continue + + module_q = lora.modules.get(lora_layer_name + "_q_proj", None) + module_k = lora.modules.get(lora_layer_name + "_k_proj", None) + module_v = lora.modules.get(lora_layer_name + "_v_proj", None) + module_out = lora.modules.get(lora_layer_name + "_out_proj", None) + + if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out: + updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight) + updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight) + updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight) + updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) + + self.in_proj_weight += updown_qkv + self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight) + continue + + if module is None: + continue + + print(f'failed to calculate lora weights for layer {lora_layer_name}') + + setattr(self, "lora_current_names", wanted_names) + + +def lora_reset_cached_weight(self: torch.nn.Conv2d | torch.nn.Linear): + setattr(self, "lora_current_names", ()) + setattr(self, "lora_weights_backup", None) def lora_Linear_forward(self, input): - return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input)) + lora_apply_weights(self) + + return torch.nn.Linear_forward_before_lora(self, input) + + +def lora_Linear_load_state_dict(self, *args, **kwargs): + lora_reset_cached_weight(self) + + return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs) def lora_Conv2d_forward(self, input): - return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input)) + lora_apply_weights(self) + + return torch.nn.Conv2d_forward_before_lora(self, input) + + +def lora_Conv2d_load_state_dict(self, *args, **kwargs): + lora_reset_cached_weight(self) + + return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs) + + +def lora_MultiheadAttention_forward(self, *args, **kwargs): + lora_apply_weights(self) + + return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs) + + +def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs): + lora_reset_cached_weight(self) + + return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs) def list_available_loras(): diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 2e860160..0adab225 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -9,7 +9,11 @@ from modules import script_callbacks, ui_extra_networks, extra_networks, shared def unload(): torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora + torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora + torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora + torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora + torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora def before_ui(): @@ -20,11 +24,27 @@ def before_ui(): if not hasattr(torch.nn, 'Linear_forward_before_lora'): torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward +if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'): + torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict + if not hasattr(torch.nn, 'Conv2d_forward_before_lora'): torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward +if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'): + torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict + +if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'): + torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward + +if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'): + torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict + torch.nn.Linear.forward = lora.lora_Linear_forward +torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict torch.nn.Conv2d.forward = lora.lora_Conv2d_forward +torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict +torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward +torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) script_callbacks.on_script_unloaded(unload) @@ -33,6 +53,4 @@ script_callbacks.on_before_ui(before_ui) shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras), - "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"), - }))