fix for incorrect model weight loading for #814
This commit is contained in:
parent
965dcf4469
commit
c715ef04d1
2 changed files with 14 additions and 1 deletions
|
@ -245,6 +245,7 @@ class StableDiffusionModelHijack:
|
|||
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
|
||||
self.clip = m.cond_stage_model
|
||||
|
||||
if cmd_opts.opt_split_attention_v1:
|
||||
|
@ -263,6 +264,14 @@ class StableDiffusionModelHijack:
|
|||
|
||||
self.layers = flatten(m)
|
||||
|
||||
def undo_hijack(self, m):
|
||||
if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
|
||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
|
||||
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
|
||||
|
||||
def apply_circular(self, enable):
|
||||
if self.circular_enabled == enable:
|
||||
return
|
||||
|
|
|
@ -137,7 +137,7 @@ def load_model():
|
|||
|
||||
|
||||
def reload_model_weights(sd_model, info=None):
|
||||
from modules import lowvram, devices
|
||||
from modules import lowvram, devices, sd_hijack
|
||||
checkpoint_info = info or select_checkpoint()
|
||||
|
||||
if sd_model.sd_model_checkpint == checkpoint_info.filename:
|
||||
|
@ -148,8 +148,12 @@ def reload_model_weights(sd_model, info=None):
|
|||
else:
|
||||
sd_model.to(devices.cpu)
|
||||
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
|
||||
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
|
||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||
sd_model.to(devices.device)
|
||||
|
||||
|
|
Loading…
Reference in a new issue