fix for incorrect model weight loading for #814

This commit is contained in:
AUTOMATIC 2022-09-29 15:40:28 +03:00
parent 965dcf4469
commit c715ef04d1
2 changed files with 14 additions and 1 deletions

View file

@ -245,6 +245,7 @@ class StableDiffusionModelHijack:
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
self.clip = m.cond_stage_model self.clip = m.cond_stage_model
if cmd_opts.opt_split_attention_v1: if cmd_opts.opt_split_attention_v1:
@ -263,6 +264,14 @@ class StableDiffusionModelHijack:
self.layers = flatten(m) self.layers = flatten(m)
def undo_hijack(self, m):
if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
def apply_circular(self, enable): def apply_circular(self, enable):
if self.circular_enabled == enable: if self.circular_enabled == enable:
return return

View file

@ -137,7 +137,7 @@ def load_model():
def reload_model_weights(sd_model, info=None): def reload_model_weights(sd_model, info=None):
from modules import lowvram, devices from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint() checkpoint_info = info or select_checkpoint()
if sd_model.sd_model_checkpint == checkpoint_info.filename: if sd_model.sd_model_checkpint == checkpoint_info.filename:
@ -148,8 +148,12 @@ def reload_model_weights(sd_model, info=None):
else: else:
sd_model.to(devices.cpu) sd_model.to(devices.cpu)
sd_hijack.model_hijack.undo_hijack(sd_model)
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash) load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
sd_hijack.model_hijack.hijack(sd_model)
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device) sd_model.to(devices.device)