apply Lora by altering layer's weights instead of adding more calculations in forward()
This commit is contained in:
parent
69eb2a9ee8
commit
80b26d2a69
2 changed files with 66 additions and 18 deletions
|
@ -131,7 +131,7 @@ def load_lora(name, filename):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
module.weight.copy_(weight)
|
module.weight.copy_(weight)
|
||||||
|
|
||||||
module.to(device=devices.device, dtype=devices.dtype)
|
module.to(device=devices.cpu, dtype=devices.dtype)
|
||||||
|
|
||||||
if lora_key == "lora_up.weight":
|
if lora_key == "lora_up.weight":
|
||||||
lora_module.up = module
|
lora_module.up = module
|
||||||
|
@ -177,29 +177,69 @@ def load_loras(names, multipliers=None):
|
||||||
loaded_loras.append(lora)
|
loaded_loras.append(lora)
|
||||||
|
|
||||||
|
|
||||||
def lora_forward(module, input, res):
|
def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear):
|
||||||
input = devices.cond_cast_unet(input)
|
"""
|
||||||
if len(loaded_loras) == 0:
|
Applies the currently selected set of Loras to the weight of torch layer self.
|
||||||
return res
|
If weights already have this particular set of loras applied, does nothing.
|
||||||
|
If not, restores orginal weights from backup and alters weights according to loras.
|
||||||
|
"""
|
||||||
|
|
||||||
lora_layer_name = getattr(module, 'lora_layer_name', None)
|
current_names = getattr(self, "lora_current_names", ())
|
||||||
|
wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)
|
||||||
|
|
||||||
|
weights_backup = getattr(self, "lora_weights_backup", None)
|
||||||
|
if weights_backup is None:
|
||||||
|
weights_backup = self.weight.to(devices.cpu, copy=True)
|
||||||
|
self.lora_weights_backup = weights_backup
|
||||||
|
|
||||||
|
if current_names != wanted_names:
|
||||||
|
if weights_backup is not None:
|
||||||
|
self.weight.copy_(weights_backup)
|
||||||
|
|
||||||
|
lora_layer_name = getattr(self, 'lora_layer_name', None)
|
||||||
for lora in loaded_loras:
|
for lora in loaded_loras:
|
||||||
module = lora.modules.get(lora_layer_name, None)
|
module = lora.modules.get(lora_layer_name, None)
|
||||||
if module is not None:
|
if module is None:
|
||||||
if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
|
continue
|
||||||
res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
|
||||||
else:
|
|
||||||
res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
|
||||||
|
|
||||||
return res
|
with torch.no_grad():
|
||||||
|
up = module.up.weight.to(self.weight.device, dtype=self.weight.dtype)
|
||||||
|
down = module.down.weight.to(self.weight.device, dtype=self.weight.dtype)
|
||||||
|
|
||||||
|
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
||||||
|
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||||
|
else:
|
||||||
|
updown = up @ down
|
||||||
|
|
||||||
|
self.weight += updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
||||||
|
|
||||||
|
setattr(self, "lora_current_names", wanted_names)
|
||||||
|
|
||||||
|
|
||||||
def lora_Linear_forward(self, input):
|
def lora_Linear_forward(self, input):
|
||||||
return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))
|
lora_apply_weights(self)
|
||||||
|
|
||||||
|
return torch.nn.Linear_forward_before_lora(self, input)
|
||||||
|
|
||||||
|
|
||||||
|
def lora_Linear_load_state_dict(self: torch.nn.Linear, *args, **kwargs):
|
||||||
|
setattr(self, "lora_current_names", ())
|
||||||
|
setattr(self, "lora_weights_backup", None)
|
||||||
|
|
||||||
|
return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def lora_Conv2d_forward(self, input):
|
def lora_Conv2d_forward(self, input):
|
||||||
return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
|
lora_apply_weights(self)
|
||||||
|
|
||||||
|
return torch.nn.Conv2d_forward_before_lora(self, input)
|
||||||
|
|
||||||
|
|
||||||
|
def lora_Conv2d_load_state_dict(self: torch.nn.Conv2d, *args, **kwargs):
|
||||||
|
setattr(self, "lora_current_names", ())
|
||||||
|
setattr(self, "lora_weights_backup", None)
|
||||||
|
|
||||||
|
return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def list_available_loras():
|
def list_available_loras():
|
||||||
|
|
|
@ -9,7 +9,9 @@ from modules import script_callbacks, ui_extra_networks, extra_networks, shared
|
||||||
|
|
||||||
def unload():
|
def unload():
|
||||||
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
|
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
|
||||||
|
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
|
||||||
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
|
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
|
||||||
|
torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
|
||||||
|
|
||||||
|
|
||||||
def before_ui():
|
def before_ui():
|
||||||
|
@ -20,11 +22,19 @@ def before_ui():
|
||||||
if not hasattr(torch.nn, 'Linear_forward_before_lora'):
|
if not hasattr(torch.nn, 'Linear_forward_before_lora'):
|
||||||
torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
|
torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
|
||||||
|
|
||||||
|
if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
|
||||||
|
torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict
|
||||||
|
|
||||||
if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
|
if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
|
||||||
torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
|
torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
|
||||||
|
|
||||||
|
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
|
||||||
|
torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
|
||||||
|
|
||||||
torch.nn.Linear.forward = lora.lora_Linear_forward
|
torch.nn.Linear.forward = lora.lora_Linear_forward
|
||||||
|
torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
|
||||||
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
|
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
|
||||||
|
torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
|
||||||
|
|
||||||
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
|
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
|
||||||
script_callbacks.on_script_unloaded(unload)
|
script_callbacks.on_script_unloaded(unload)
|
||||||
|
@ -33,6 +43,4 @@ script_callbacks.on_before_ui(before_ui)
|
||||||
|
|
||||||
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
|
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
|
||||||
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
|
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
|
||||||
"lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
|
|
||||||
|
|
||||||
}))
|
}))
|
||||||
|
|
Loading…
Reference in a new issue