Merge pull request #5589 from MrCheeze/better-special-model-support
Better support for 2.0-inpainting and 2.0-depth special models
This commit is contained in:
commit
94450b8877
3 changed files with 12 additions and 8 deletions
|
@ -55,18 +55,20 @@ def setup_for_low_vram(sd_model, use_medvram):
|
||||||
if hasattr(sd_model.cond_stage_model, 'model'):
|
if hasattr(sd_model.cond_stage_model, 'model'):
|
||||||
sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
|
sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
|
||||||
|
|
||||||
# remove three big modules, cond, first_stage, and unet from the model and then
|
# remove four big modules, cond, first_stage, depth (if applicable), and unet from the model and then
|
||||||
# send the model to GPU. Then put modules back. the modules will be in CPU.
|
# send the model to GPU. Then put modules back. the modules will be in CPU.
|
||||||
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
|
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), sd_model.model
|
||||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
|
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = None, None, None, None
|
||||||
sd_model.to(devices.device)
|
sd_model.to(devices.device)
|
||||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
|
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = stored
|
||||||
|
|
||||||
# register hooks for those the first two models
|
# register hooks for those the first three models
|
||||||
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
|
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
|
||||||
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
||||||
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
|
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
|
||||||
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
|
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
|
||||||
|
if sd_model.depth_model:
|
||||||
|
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
|
||||||
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
||||||
|
|
||||||
if hasattr(sd_model.cond_stage_model, 'model'):
|
if hasattr(sd_model.cond_stage_model, 'model'):
|
||||||
|
|
|
@ -324,12 +324,11 @@ def should_hijack_inpainting(checkpoint_info):
|
||||||
|
|
||||||
def do_inpainting_hijack():
|
def do_inpainting_hijack():
|
||||||
# most of this stuff seems to no longer be needed because it is already included into SD2.0
|
# most of this stuff seems to no longer be needed because it is already included into SD2.0
|
||||||
# LatentInpaintDiffusion remains because SD2.0's LatentInpaintDiffusion can't be loaded without specifying a checkpoint
|
|
||||||
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
|
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
|
||||||
# this file should be cleaned up later if weverything tuens out to work fine
|
# this file should be cleaned up later if weverything tuens out to work fine
|
||||||
|
|
||||||
# ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
|
# ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
|
||||||
ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
|
# ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
|
||||||
|
|
||||||
# ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
|
# ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
|
||||||
# ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
|
# ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
|
||||||
|
|
|
@ -293,13 +293,16 @@ def load_model(checkpoint_info=None):
|
||||||
if should_hijack_inpainting(checkpoint_info):
|
if should_hijack_inpainting(checkpoint_info):
|
||||||
# Hardcoded config for now...
|
# Hardcoded config for now...
|
||||||
sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
|
sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
|
||||||
sd_config.model.params.use_ema = False
|
|
||||||
sd_config.model.params.conditioning_key = "hybrid"
|
sd_config.model.params.conditioning_key = "hybrid"
|
||||||
sd_config.model.params.unet_config.params.in_channels = 9
|
sd_config.model.params.unet_config.params.in_channels = 9
|
||||||
|
sd_config.model.params.finetune_keys = None
|
||||||
|
|
||||||
# Create a "fake" config with a different name so that we know to unload it when switching models.
|
# Create a "fake" config with a different name so that we know to unload it when switching models.
|
||||||
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
|
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
|
||||||
|
|
||||||
|
if not hasattr(sd_config.model.params, "use_ema"):
|
||||||
|
sd_config.model.params.use_ema = False
|
||||||
|
|
||||||
do_inpainting_hijack()
|
do_inpainting_hijack()
|
||||||
|
|
||||||
if shared.cmd_opts.no_half:
|
if shared.cmd_opts.no_half:
|
||||||
|
|
Loading…
Reference in a new issue