allow having at half precision when there is only one checkpoint in merger tab
This commit is contained in:
parent
0f5dbfffd0
commit
54674674b8
1 changed files with 13 additions and 3 deletions
|
@ -278,6 +278,13 @@ def create_config(ckpt_result, config_source, a, b, c):
|
||||||
chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
|
chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
|
||||||
|
|
||||||
|
|
||||||
|
def to_half(tensor, enable):
|
||||||
|
if enable and tensor.dtype == torch.float:
|
||||||
|
return tensor.half()
|
||||||
|
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae):
|
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae):
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
shared.state.job = 'model-merge'
|
shared.state.job = 'model-merge'
|
||||||
|
@ -400,8 +407,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||||
else:
|
else:
|
||||||
theta_0[key] = theta_func2(a, b, multiplier)
|
theta_0[key] = theta_func2(a, b, multiplier)
|
||||||
|
|
||||||
if save_as_half:
|
theta_0[key] = to_half(theta_0[key], save_as_half)
|
||||||
theta_0[key] = theta_0[key].half()
|
|
||||||
|
|
||||||
shared.state.sampling_step += 1
|
shared.state.sampling_step += 1
|
||||||
|
|
||||||
|
@ -416,10 +422,14 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||||
for key in vae_dict.keys():
|
for key in vae_dict.keys():
|
||||||
theta_0_key = 'first_stage_model.' + key
|
theta_0_key = 'first_stage_model.' + key
|
||||||
if theta_0_key in theta_0:
|
if theta_0_key in theta_0:
|
||||||
theta_0[theta_0_key] = vae_dict[key].half() if save_as_half else vae_dict[key]
|
theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)
|
||||||
|
|
||||||
del vae_dict
|
del vae_dict
|
||||||
|
|
||||||
|
if save_as_half and not theta_func2:
|
||||||
|
for key in theta_0.keys():
|
||||||
|
theta_0[key] = to_half(theta_0[key], save_as_half)
|
||||||
|
|
||||||
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
||||||
|
|
||||||
filename = filename_generator() if custom_name == '' else custom_name
|
filename = filename_generator() if custom_name == '' else custom_name
|
||||||
|
|
Loading…
Reference in a new issue