Unload sd_model before loading the other
This commit is contained in:
parent
5c9b3625fa
commit
af758e97fa
5 changed files with 34 additions and 10 deletions
|
@ -38,13 +38,18 @@ def setup_for_low_vram(sd_model, use_medvram):
|
||||||
# see below for register_forward_pre_hook;
|
# see below for register_forward_pre_hook;
|
||||||
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
|
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
|
||||||
# useless here, and we just replace those methods
|
# useless here, and we just replace those methods
|
||||||
def first_stage_model_encode_wrap(self, encoder, x):
|
|
||||||
send_me_to_gpu(self, None)
|
|
||||||
return encoder(x)
|
|
||||||
|
|
||||||
def first_stage_model_decode_wrap(self, decoder, z):
|
first_stage_model = sd_model.first_stage_model
|
||||||
send_me_to_gpu(self, None)
|
first_stage_model_encode = sd_model.first_stage_model.encode
|
||||||
return decoder(z)
|
first_stage_model_decode = sd_model.first_stage_model.decode
|
||||||
|
|
||||||
|
def first_stage_model_encode_wrap(x):
|
||||||
|
send_me_to_gpu(first_stage_model, None)
|
||||||
|
return first_stage_model_encode(x)
|
||||||
|
|
||||||
|
def first_stage_model_decode_wrap(z):
|
||||||
|
send_me_to_gpu(first_stage_model, None)
|
||||||
|
return first_stage_model_decode(z)
|
||||||
|
|
||||||
# remove three big modules, cond, first_stage, and unet from the model and then
|
# remove three big modules, cond, first_stage, and unet from the model and then
|
||||||
# send the model to GPU. Then put modules back. the modules will be in CPU.
|
# send the model to GPU. Then put modules back. the modules will be in CPU.
|
||||||
|
@ -56,8 +61,8 @@ def setup_for_low_vram(sd_model, use_medvram):
|
||||||
# register hooks for those the first two models
|
# register hooks for those the first two models
|
||||||
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
|
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
|
||||||
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
||||||
sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
|
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
|
||||||
sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
|
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
|
||||||
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
||||||
|
|
||||||
if use_medvram:
|
if use_medvram:
|
||||||
|
|
|
@ -597,6 +597,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||||
if p.scripts is not None:
|
if p.scripts is not None:
|
||||||
p.scripts.postprocess(p, res)
|
p.scripts.postprocess(p, res)
|
||||||
|
|
||||||
|
p.sd_model = None
|
||||||
|
p.sampler = None
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -94,6 +94,10 @@ class StableDiffusionModelHijack:
|
||||||
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
|
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
|
||||||
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
|
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
|
||||||
|
|
||||||
|
self.layers = None
|
||||||
|
self.circular_enabled = False
|
||||||
|
self.clip = None
|
||||||
|
|
||||||
def apply_circular(self, enable):
|
def apply_circular(self, enable):
|
||||||
if self.circular_enabled == enable:
|
if self.circular_enabled == enable:
|
||||||
return
|
return
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import collections
|
import collections
|
||||||
import os.path
|
import os.path
|
||||||
import sys
|
import sys
|
||||||
|
import gc
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
import torch
|
import torch
|
||||||
import re
|
import re
|
||||||
|
@ -220,6 +221,12 @@ def load_model(checkpoint_info=None):
|
||||||
if checkpoint_info.config != shared.cmd_opts.config:
|
if checkpoint_info.config != shared.cmd_opts.config:
|
||||||
print(f"Loading config from: {checkpoint_info.config}")
|
print(f"Loading config from: {checkpoint_info.config}")
|
||||||
|
|
||||||
|
if shared.sd_model:
|
||||||
|
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
||||||
|
shared.sd_model = None
|
||||||
|
gc.collect()
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
sd_config = OmegaConf.load(checkpoint_info.config)
|
sd_config = OmegaConf.load(checkpoint_info.config)
|
||||||
|
|
||||||
if should_hijack_inpainting(checkpoint_info):
|
if should_hijack_inpainting(checkpoint_info):
|
||||||
|
@ -233,6 +240,7 @@ def load_model(checkpoint_info=None):
|
||||||
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
|
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
|
||||||
|
|
||||||
do_inpainting_hijack()
|
do_inpainting_hijack()
|
||||||
|
|
||||||
sd_model = instantiate_from_config(sd_config.model)
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
load_model_weights(sd_model, checkpoint_info)
|
load_model_weights(sd_model, checkpoint_info)
|
||||||
|
|
||||||
|
@ -252,14 +260,18 @@ def load_model(checkpoint_info=None):
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
|
||||||
def reload_model_weights(sd_model, info=None):
|
def reload_model_weights(sd_model=None, info=None):
|
||||||
from modules import lowvram, devices, sd_hijack
|
from modules import lowvram, devices, sd_hijack
|
||||||
checkpoint_info = info or select_checkpoint()
|
checkpoint_info = info or select_checkpoint()
|
||||||
|
|
||||||
|
if not sd_model:
|
||||||
|
sd_model = shared.sd_model
|
||||||
|
|
||||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||||
return
|
return
|
||||||
|
|
||||||
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
|
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
|
||||||
|
del sd_model
|
||||||
checkpoints_loaded.clear()
|
checkpoints_loaded.clear()
|
||||||
load_model(checkpoint_info)
|
load_model(checkpoint_info)
|
||||||
return shared.sd_model
|
return shared.sd_model
|
||||||
|
|
2
webui.py
2
webui.py
|
@ -77,7 +77,7 @@ def initialize():
|
||||||
modules.scripts.load_scripts()
|
modules.scripts.load_scripts()
|
||||||
|
|
||||||
modules.sd_models.load_model()
|
modules.sd_models.load_model()
|
||||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
|
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
|
||||||
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
|
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
|
||||||
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
|
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue