apply Lora by altering layer's weights instead of adding more calculations in forward()
This commit is contained in:
parent
69eb2a9ee8
commit
80b26d2a69
2 changed files with 66 additions and 18 deletions
|
@ -131,7 +131,7 @@ def load_lora(name, filename):
|
|||
with torch.no_grad():
|
||||
module.weight.copy_(weight)
|
||||
|
||||
module.to(device=devices.device, dtype=devices.dtype)
|
||||
module.to(device=devices.cpu, dtype=devices.dtype)
|
||||
|
||||
if lora_key == "lora_up.weight":
|
||||
lora_module.up = module
|
||||
|
@ -177,29 +177,69 @@ def load_loras(names, multipliers=None):
|
|||
loaded_loras.append(lora)
|
||||
|
||||
|
||||
def lora_forward(module, input, res):
|
||||
input = devices.cond_cast_unet(input)
|
||||
if len(loaded_loras) == 0:
|
||||
return res
|
||||
def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear):
|
||||
"""
|
||||
Applies the currently selected set of Loras to the weight of torch layer self.
|
||||
If weights already have this particular set of loras applied, does nothing.
|
||||
If not, restores orginal weights from backup and alters weights according to loras.
|
||||
"""
|
||||
|
||||
lora_layer_name = getattr(module, 'lora_layer_name', None)
|
||||
for lora in loaded_loras:
|
||||
module = lora.modules.get(lora_layer_name, None)
|
||||
if module is not None:
|
||||
if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
|
||||
res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
||||
else:
|
||||
res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
||||
current_names = getattr(self, "lora_current_names", ())
|
||||
wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)
|
||||
|
||||
return res
|
||||
weights_backup = getattr(self, "lora_weights_backup", None)
|
||||
if weights_backup is None:
|
||||
weights_backup = self.weight.to(devices.cpu, copy=True)
|
||||
self.lora_weights_backup = weights_backup
|
||||
|
||||
if current_names != wanted_names:
|
||||
if weights_backup is not None:
|
||||
self.weight.copy_(weights_backup)
|
||||
|
||||
lora_layer_name = getattr(self, 'lora_layer_name', None)
|
||||
for lora in loaded_loras:
|
||||
module = lora.modules.get(lora_layer_name, None)
|
||||
if module is None:
|
||||
continue
|
||||
|
||||
with torch.no_grad():
|
||||
up = module.up.weight.to(self.weight.device, dtype=self.weight.dtype)
|
||||
down = module.down.weight.to(self.weight.device, dtype=self.weight.dtype)
|
||||
|
||||
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
||||
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
updown = up @ down
|
||||
|
||||
self.weight += updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
||||
|
||||
setattr(self, "lora_current_names", wanted_names)
|
||||
|
||||
|
||||
def lora_Linear_forward(self, input):
|
||||
return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))
|
||||
lora_apply_weights(self)
|
||||
|
||||
return torch.nn.Linear_forward_before_lora(self, input)
|
||||
|
||||
|
||||
def lora_Linear_load_state_dict(self: torch.nn.Linear, *args, **kwargs):
|
||||
setattr(self, "lora_current_names", ())
|
||||
setattr(self, "lora_weights_backup", None)
|
||||
|
||||
return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)
|
||||
|
||||
|
||||
def lora_Conv2d_forward(self, input):
|
||||
return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
|
||||
lora_apply_weights(self)
|
||||
|
||||
return torch.nn.Conv2d_forward_before_lora(self, input)
|
||||
|
||||
|
||||
def lora_Conv2d_load_state_dict(self: torch.nn.Conv2d, *args, **kwargs):
|
||||
setattr(self, "lora_current_names", ())
|
||||
setattr(self, "lora_weights_backup", None)
|
||||
|
||||
return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)
|
||||
|
||||
|
||||
def list_available_loras():
|
||||
|
|
|
@ -9,7 +9,9 @@ from modules import script_callbacks, ui_extra_networks, extra_networks, shared
|
|||
|
||||
def unload():
|
||||
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
|
||||
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
|
||||
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
|
||||
torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
|
||||
|
||||
|
||||
def before_ui():
|
||||
|
@ -20,11 +22,19 @@ def before_ui():
|
|||
if not hasattr(torch.nn, 'Linear_forward_before_lora'):
|
||||
torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
|
||||
|
||||
if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
|
||||
torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict
|
||||
|
||||
if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
|
||||
torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
|
||||
|
||||
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
|
||||
torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
|
||||
|
||||
torch.nn.Linear.forward = lora.lora_Linear_forward
|
||||
torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
|
||||
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
|
||||
torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
|
||||
|
||||
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
|
||||
script_callbacks.on_script_unloaded(unload)
|
||||
|
@ -33,6 +43,4 @@ script_callbacks.on_before_ui(before_ui)
|
|||
|
||||
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
|
||||
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
|
||||
"lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
|
||||
|
||||
}))
|
||||
|
|
Loading…
Reference in a new issue