diff --git a/modules/extras.py b/modules/extras.py index 532d869f..2e7b3751 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -209,7 +209,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam for key in tqdm.tqdm(theta_0.keys()): if 'model' in key and key in theta_1: - theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key] if theta_2 else None, (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint + t2 = (theta_2 or {}).get(key) + if t2 is None: + t2 = torch.zeros_like(theta_0[key]) + + theta_0[key] = theta_func(theta_0[key], theta_1[key], t2, (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint + if save_as_half: theta_0[key] = theta_0[key].half()