fix #3986 breaking --no-half-vae
This commit is contained in:
parent
675b51ebd3
commit
f2a5cbe6f5
1 changed files with 9 additions and 0 deletions
|
@ -183,11 +183,20 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
|
|||
model.to(memory_format=torch.channels_last)
|
||||
|
||||
if not shared.cmd_opts.no_half:
|
||||
vae = model.first_stage_model
|
||||
|
||||
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
|
||||
if shared.cmd_opts.no_half_vae:
|
||||
model.first_stage_model = None
|
||||
|
||||
model.half()
|
||||
model.first_stage_model = vae
|
||||
|
||||
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
||||
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
|
||||
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
|
||||
if shared.opts.sd_checkpoint_cache > 0:
|
||||
# if PR #4035 were to get merged, restore base VAE first before caching
|
||||
checkpoints_loaded[checkpoint_key] = model.state_dict().copy()
|
||||
|
|
Loading…
Reference in a new issue