fix #3986 breaking --no-half-vae
parent 675b51ebd3
commit f2a5cbe6f5
1 changed file with 9 additions and 0 deletions
@@ -183,11 +183,20 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
         model.to(memory_format=torch.channels_last)
 
     if not shared.cmd_opts.no_half:
+        vae = model.first_stage_model
+
+        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
+        if shared.cmd_opts.no_half_vae:
+            model.first_stage_model = None
+
         model.half()
+        model.first_stage_model = vae
 
     devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
     devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
 
+    model.first_stage_model.to(devices.dtype_vae)
+
     if shared.opts.sd_checkpoint_cache > 0:
         # if PR #4035 were to get merged, restore base VAE first before caching
         checkpoints_loaded[checkpoint_key] = model.state_dict().copy()
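The comment in the hunk captures the whole trick: nn.Module.half() recursively converts every registered submodule to float16, so the VAE is temporarily unregistered (set to None) for the duration of the call, reattached afterwards, and then cast explicitly to devices.dtype_vae. Below is a minimal, runnable sketch of that pattern; FakeModel, FakeVAE, and the no_half_vae flag are illustrative stand-ins, not the webui's actual classes or shared.cmd_opts.

import torch
import torch.nn as nn

class FakeVAE(nn.Module):
    # stand-in for the model's first_stage_model (the VAE)
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

class FakeModel(nn.Module):
    # stand-in for the full Stable Diffusion model
    def __init__(self):
        super().__init__()
        self.first_stage_model = FakeVAE()
        self.body = nn.Linear(8, 8)

no_half_vae = True  # stand-in for shared.cmd_opts.no_half_vae

model = FakeModel()

# Detach the VAE so nn.Module.half(), which recurses into all registered
# submodules, never sees it; reattach it once the rest of the model is fp16.
vae = model.first_stage_model
if no_half_vae:
    model.first_stage_model = None
model.half()
model.first_stage_model = vae

# Cast the VAE explicitly to the dtype picked for it.
dtype_vae = torch.float32 if no_half_vae else torch.float16
model.first_stage_model.to(dtype_vae)

print(model.body.weight.dtype)                    # torch.float16
print(model.first_stage_model.proj.weight.dtype)  # torch.float32

Setting a registered submodule attribute to None is enough here because Module._apply() only recurses into non-None children, which is why the VAE's weights survive half() untouched.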