make it possible to merge inpainting model with non-inpainting one

This commit is contained in:
AUTOMATIC 2022-12-04 12:30:44 +03:00
parent 8504db5170
commit 44c46f0ed3

View file

@ -247,6 +247,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
primary_model_info = sd_models.checkpoints_list[primary_model_name]
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
result_is_inpainting_model = False
print(f"Loading {primary_model_info.filename}...")
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
@ -280,8 +281,22 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1:
a = theta_0[key]
b = theta_1[key]
theta_0[key] = theta_func2(theta_0[key], theta_1[key], multiplier)
# this enables merging an inpainting model (A) with another one (B);
# where normal model would have 4 channels, for latenst space, inpainting model would
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
if a.shape[1] == 4 and b.shape[1] == 9:
raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
result_is_inpainting_model = True
else:
theta_0[key] = theta_func2(a, b, multiplier)
if save_as_half:
theta_0[key] = theta_0[key].half()
@ -295,8 +310,16 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.' + checkpoint_format
filename = \
primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \
secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \
interp_method.replace(" ", "_") + \
'-merged.' + \
("inpainting." if result_is_inpainting_model else "") + \
checkpoint_format
filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
output_modelname = os.path.join(ckpt_dir, filename)
print(f"Saving to {output_modelname}...")