Lora support for SD2

This commit is contained in:
AUTOMATIC 2023-03-26 10:44:20 +03:00
parent b705c9b72b
commit 650ddc9dd3
2 changed files with 127 additions and 40 deletions

View file

@ -8,14 +8,27 @@ from modules import shared, devices, sd_models, errors
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
re_digits = re.compile(r"\d+") re_digits = re.compile(r"\d+")
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)") re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)") re_compiled = {}
re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)") suffix_conversion = {
"attentions": {},
"resnets": {
"conv1": "in_layers_2",
"conv2": "out_layers_3",
"time_emb_proj": "emb_layers_1",
"conv_shortcut": "skip_connection",
}
}
def convert_diffusers_name_to_compvis(key, is_sd2): def convert_diffusers_name_to_compvis(key, is_sd2):
def match(match_list, regex): def match(match_list, regex_text):
regex = re_compiled.get(regex_text)
if regex is None:
regex = re.compile(regex_text)
re_compiled[regex_text] = regex
r = re.match(regex, key) r = re.match(regex, key)
if not r: if not r:
return False return False
@ -26,16 +39,25 @@ def convert_diffusers_name_to_compvis(key, is_sd2):
m = [] m = []
if match(m, re_unet_down_blocks): if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}" suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
if match(m, re_unet_mid_blocks): if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
return f"diffusion_model_middle_block_1_{m[1]}" suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
if match(m, re_unet_up_blocks): if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}" suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
if match(m, re_text_block): if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"
if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
if is_sd2: if is_sd2:
if 'mlp_fc1' in m[1]: if 'mlp_fc1' in m[1]:
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
@ -109,16 +131,22 @@ def load_lora(name, filename):
sd = sd_models.read_state_dict(filename) sd = sd_models.read_state_dict(filename)
keys_failed_to_match = [] keys_failed_to_match = {}
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
for key_diffusers, weight in sd.items(): for key_diffusers, weight in sd.items():
fullkey = convert_diffusers_name_to_compvis(key_diffusers, is_sd2) key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1)
key, lora_key = fullkey.split(".", 1) key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2)
sd_module = shared.sd_model.lora_layer_mapping.get(key, None) sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
if sd_module is None: if sd_module is None:
keys_failed_to_match.append(key_diffusers) m = re_x_proj.match(key)
if m:
sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)
if sd_module is None:
keys_failed_to_match[key_diffusers] = key
continue continue
lora_module = lora.modules.get(key, None) lora_module = lora.modules.get(key, None)
@ -133,7 +161,9 @@ def load_lora(name, filename):
if type(sd_module) == torch.nn.Linear: if type(sd_module) == torch.nn.Linear:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear: elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
module = torch.nn.modules.linear.NonDynamicallyQuantizableLinear(weight.shape[1], weight.shape[0], bias=False) module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.MultiheadAttention:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.Conv2d: elif type(sd_module) == torch.nn.Conv2d:
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
else: else:
@ -190,54 +220,94 @@ def load_loras(names, multipliers=None):
loaded_loras.append(lora) loaded_loras.append(lora)
def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear): def lora_calc_updown(lora, module, target):
with torch.no_grad():
up = module.up.weight.to(target.device, dtype=target.dtype)
down = module.down.weight.to(target.device, dtype=target.dtype)
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
else:
updown = up @ down
updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
return updown
def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.MultiheadAttention):
""" """
Applies the currently selected set of Loras to the weight of torch layer self. Applies the currently selected set of Loras to the weights of torch layer self.
If weights already have this particular set of loras applied, does nothing. If weights already have this particular set of loras applied, does nothing.
If not, restores orginal weights from backup and alters weights according to loras. If not, restores orginal weights from backup and alters weights according to loras.
""" """
lora_layer_name = getattr(self, 'lora_layer_name', None)
if lora_layer_name is None:
return
current_names = getattr(self, "lora_current_names", ()) current_names = getattr(self, "lora_current_names", ())
wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras) wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)
weights_backup = getattr(self, "lora_weights_backup", None) weights_backup = getattr(self, "lora_weights_backup", None)
if weights_backup is None: if weights_backup is None:
weights_backup = self.weight.to(devices.cpu, copy=True) if isinstance(self, torch.nn.MultiheadAttention):
weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
else:
weights_backup = self.weight.to(devices.cpu, copy=True)
self.lora_weights_backup = weights_backup self.lora_weights_backup = weights_backup
if current_names != wanted_names: if current_names != wanted_names:
if weights_backup is not None: if weights_backup is not None:
self.weight.copy_(weights_backup) if isinstance(self, torch.nn.MultiheadAttention):
self.in_proj_weight.copy_(weights_backup[0])
self.out_proj.weight.copy_(weights_backup[1])
else:
self.weight.copy_(weights_backup)
lora_layer_name = getattr(self, 'lora_layer_name', None)
for lora in loaded_loras: for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None) module = lora.modules.get(lora_layer_name, None)
if module is not None and hasattr(self, 'weight'):
self.weight += lora_calc_updown(lora, module, self.weight)
continue
module_q = lora.modules.get(lora_layer_name + "_q_proj", None)
module_k = lora.modules.get(lora_layer_name + "_k_proj", None)
module_v = lora.modules.get(lora_layer_name + "_v_proj", None)
module_out = lora.modules.get(lora_layer_name + "_out_proj", None)
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight)
updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight)
updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight)
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
self.in_proj_weight += updown_qkv
self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight)
continue
if module is None: if module is None:
continue continue
with torch.no_grad(): print(f'failed to calculate lora weights for layer {lora_layer_name}')
up = module.up.weight.to(self.weight.device, dtype=self.weight.dtype)
down = module.down.weight.to(self.weight.device, dtype=self.weight.dtype)
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
else:
updown = up @ down
self.weight += updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
setattr(self, "lora_current_names", wanted_names) setattr(self, "lora_current_names", wanted_names)
def lora_reset_cached_weight(self: torch.nn.Conv2d | torch.nn.Linear):
setattr(self, "lora_current_names", ())
setattr(self, "lora_weights_backup", None)
def lora_Linear_forward(self, input): def lora_Linear_forward(self, input):
lora_apply_weights(self) lora_apply_weights(self)
return torch.nn.Linear_forward_before_lora(self, input) return torch.nn.Linear_forward_before_lora(self, input)
def lora_Linear_load_state_dict(self: torch.nn.Linear, *args, **kwargs): def lora_Linear_load_state_dict(self, *args, **kwargs):
setattr(self, "lora_current_names", ()) lora_reset_cached_weight(self)
setattr(self, "lora_weights_backup", None)
return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs) return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)
@ -248,15 +318,22 @@ def lora_Conv2d_forward(self, input):
return torch.nn.Conv2d_forward_before_lora(self, input) return torch.nn.Conv2d_forward_before_lora(self, input)
def lora_Conv2d_load_state_dict(self: torch.nn.Conv2d, *args, **kwargs): def lora_Conv2d_load_state_dict(self, *args, **kwargs):
setattr(self, "lora_current_names", ()) lora_reset_cached_weight(self)
setattr(self, "lora_weights_backup", None)
return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs) return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)
def lora_NonDynamicallyQuantizableLinear_forward(self, input): def lora_MultiheadAttention_forward(self, *args, **kwargs):
return lora_forward(self, input, torch.nn.NonDynamicallyQuantizableLinear_forward_before_lora(self, input)) lora_apply_weights(self)
return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs)
def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
lora_reset_cached_weight(self)
return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs)
def list_available_loras(): def list_available_loras():

View file

@ -12,6 +12,8 @@ def unload():
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora
torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora
def before_ui(): def before_ui():
@ -31,10 +33,18 @@ if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'): if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'):
torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward
if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'):
torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict
torch.nn.Linear.forward = lora.lora_Linear_forward torch.nn.Linear.forward = lora.lora_Linear_forward
torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward
torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload) script_callbacks.on_script_unloaded(unload)