Add custom name and try-except
This commit is contained in:
parent
819fd3af40
commit
66fed8ffb8
2 changed files with 15 additions and 2 deletions
|
@ -141,7 +141,7 @@ def run_pnginfo(image):
|
|||
return '', geninfo, info
|
||||
|
||||
|
||||
def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half):
|
||||
def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name):
|
||||
# Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
|
||||
def weighted_sum(theta0, theta1, alpha):
|
||||
return ((1 - alpha) * theta0) + (alpha * theta1)
|
||||
|
@ -190,6 +190,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
|
|||
theta_0[key] = theta_0[key].half()
|
||||
|
||||
filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
|
||||
filename = filename if custom_name == '' else (custom_name + '.ckpt')
|
||||
output_modelname = os.path.join(shared.cmd_opts.ckpt_dir, filename)
|
||||
|
||||
print(f"Saving to {output_modelname}...")
|
||||
|
|
|
@ -882,6 +882,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||
with gr.Row():
|
||||
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
|
||||
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
|
||||
custom_name = gr.Textbox(label="Custom Name (Optional)")
|
||||
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3)
|
||||
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
|
||||
save_as_half = gr.Checkbox(value=False, label="Safe as float16")
|
||||
|
@ -1031,15 +1032,26 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||
inputs=components,
|
||||
outputs=[result, text_settings],
|
||||
)
|
||||
|
||||
def modelmerger(*args):
|
||||
try:
|
||||
results = run_modelmerger(*args)
|
||||
except Exception as e:
|
||||
print("Error loading/saving model file:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
modules.sd_models.list_models() #To remove the potentially missing models from the list
|
||||
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
|
||||
return results
|
||||
|
||||
modelmerger_merge.click(
|
||||
fn=run_modelmerger,
|
||||
fn=modelmerger,
|
||||
inputs=[
|
||||
primary_model_name,
|
||||
secondary_model_name,
|
||||
interp_method,
|
||||
interp_amount,
|
||||
save_as_half,
|
||||
custom_name,
|
||||
],
|
||||
outputs=[
|
||||
submit_result,
|
||||
|
|
Loading…
Reference in a new issue