error out with a readable message in chwewckpoint merger for incompatible tensor shapes (ie when trying to merge SD1.5 with SD2.0)
This commit is contained in:
parent
4dbde228ff
commit
84dd7e8e24
2 changed files with 3 additions and 1 deletions
|
@ -303,6 +303,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||||
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
|
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
|
||||||
result_is_inpainting_model = True
|
result_is_inpainting_model = True
|
||||||
else:
|
else:
|
||||||
|
assert a.shape == b.shape, f'Incompatible shapes for layer {key}: A is {a.shape}, and B is {b.shape}'
|
||||||
|
|
||||||
theta_0[key] = theta_func2(a, b, multiplier)
|
theta_0[key] = theta_func2(a, b, multiplier)
|
||||||
|
|
||||||
if save_as_half:
|
if save_as_half:
|
||||||
|
|
|
@ -1663,7 +1663,7 @@ def create_ui():
|
||||||
print("Error loading/saving model file:", file=sys.stderr)
|
print("Error loading/saving model file:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
modules.sd_models.list_models() # to remove the potentially missing models from the list
|
modules.sd_models.list_models() # to remove the potentially missing models from the list
|
||||||
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
|
return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)]
|
||||||
return results
|
return results
|
||||||
|
|
||||||
modelmerger_merge.click(
|
modelmerger_merge.click(
|
||||||
|
|
Loading…
Reference in a new issue