init job and add info to model merge

This commit is contained in:
Vladimir Mandic 2023-01-03 10:21:51 -05:00 committed by GitHub
parent e9fb9bb0c2
commit 1d9dc48efd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -242,6 +242,9 @@ def run_pnginfo(image):
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format): def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
shared.state.begin()
shared.state.job = 'model-merge'
def weighted_sum(theta0, theta1, alpha): def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1) return ((1 - alpha) * theta0) + (alpha * theta1)
@ -263,8 +266,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_func1, theta_func2 = theta_funcs[interp_method] theta_func1, theta_func2 = theta_funcs[interp_method]
if theta_func1 and not tertiary_model_info: if theta_func1 and not tertiary_model_info:
shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
shared.state.end()
return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
print(f"Loading {secondary_model_info.filename}...") print(f"Loading {secondary_model_info.filename}...")
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
@ -281,6 +287,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_1[key] = torch.zeros_like(theta_1[key]) theta_1[key] = torch.zeros_like(theta_1[key])
del theta_2 del theta_2
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
print(f"Loading {primary_model_info.filename}...") print(f"Loading {primary_model_info.filename}...")
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
@ -291,6 +298,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
a = theta_0[key] a = theta_0[key]
b = theta_1[key] b = theta_1[key]
shared.state.textinfo = f'Merging layer {key}'
# this enables merging an inpainting model (A) with another one (B); # this enables merging an inpainting model (A) with another one (B);
# where normal model would have 4 channels, for latenst space, inpainting model would # where normal model would have 4 channels, for latenst space, inpainting model would
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
@ -303,8 +311,6 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
result_is_inpainting_model = True result_is_inpainting_model = True
else: else:
assert a.shape == b.shape, f'Incompatible shapes for layer {key}: A is {a.shape}, and B is {b.shape}'
theta_0[key] = theta_func2(a, b, multiplier) theta_0[key] = theta_func2(a, b, multiplier)
if save_as_half: if save_as_half:
@ -332,6 +338,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
output_modelname = os.path.join(ckpt_dir, filename) output_modelname = os.path.join(ckpt_dir, filename)
shared.state.textinfo = f"Saving to {output_modelname}..."
print(f"Saving to {output_modelname}...") print(f"Saving to {output_modelname}...")
_, extension = os.path.splitext(output_modelname) _, extension = os.path.splitext(output_modelname)
@ -343,4 +350,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
sd_models.list_models() sd_models.list_models()
print("Checkpoint saved.") print("Checkpoint saved.")
shared.state.textinfo = "Checkpoint saved to " + output_modelname
shared.state.end()
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]