added console outputs, more clear indication of progress, and ability to specify full filename to checkpoint merger

restore "Loading..." text
This commit is contained in:
AUTOMATIC 2022-09-27 10:44:00 +03:00
parent a9dc307a21
commit ada901ed66
2 changed files with 35 additions and 16 deletions

View file

@ -4,6 +4,7 @@ import numpy as np
from PIL import Image from PIL import Image
import torch import torch
import tqdm
from modules import processing, shared, images, devices from modules import processing, shared, images, devices
from modules.shared import opts from modules.shared import opts
@ -149,28 +150,45 @@ def run_modelmerger(modelname_0, modelname_1, interp_method, interp_amount):
alpha = alpha * alpha * (3 - (2 * alpha)) alpha = alpha * alpha * (3 - (2 * alpha))
return theta0 + ((theta1 - theta0) * alpha) return theta0 + ((theta1 - theta0) * alpha)
model_0 = torch.load('models/' + modelname_0 + '.ckpt') if os.path.exists(modelname_0):
model_1 = torch.load('models/' + modelname_1 + '.ckpt') model0_filename = modelname_0
modelname_0 = os.path.splitext(os.path.basename(modelname_0))[0]
else:
model0_filename = 'models/' + modelname_0 + '.ckpt'
if os.path.exists(modelname_1):
model1_filename = modelname_1
modelname_1 = os.path.splitext(os.path.basename(modelname_1))[0]
else:
model1_filename = 'models/' + modelname_1 + '.ckpt'
print(f"Loading {model0_filename}...")
model_0 = torch.load(model0_filename, map_location='cpu')
print(f"Loading {model1_filename}...")
model_1 = torch.load(model1_filename, map_location='cpu')
theta_0 = model_0['state_dict'] theta_0 = model_0['state_dict']
theta_1 = model_1['state_dict'] theta_1 = model_1['state_dict']
theta_func = weighted_sum
theta_funcs = {
if interp_method == "Weighted Sum": "Weighted Sum": weighted_sum,
theta_func = weighted_sum "Sigmoid": sigmoid,
if interp_method == "Sigmoid": }
theta_func = sigmoid theta_func = theta_funcs[interp_method]
for key in theta_0.keys(): print(f"Merging...")
for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1: if 'model' in key and key in theta_1:
theta_0[key] = theta_func(theta_0[key], theta_1[key], interp_amount) theta_0[key] = theta_func(theta_0[key], theta_1[key], interp_amount)
for key in theta_1.keys(): for key in theta_1.keys():
if 'model' in key and key not in theta_0: if 'model' in key and key not in theta_0:
theta_0[key] = theta_1[key] theta_0[key] = theta_1[key]
output_modelname = 'models/' + modelname_0 + '-' + modelname_1 + '-merged.ckpt'; output_modelname = 'models/' + modelname_0 + '-' + modelname_1 + '-merged.ckpt'
print(f"Saving to {output_modelname}...")
torch.save(model_0, output_modelname) torch.save(model_0, output_modelname)
return "<p>Model saved to " + output_modelname + "</p>" print(f"Checkpoint saved.")
return "Checkpoint saved to " + output_modelname

View file

@ -49,6 +49,7 @@ sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
css_hide_progressbar = """ css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; } .wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." }
.progress-bar { display:none!important; } .progress-bar { display:none!important; }
.meta-text { display:none!important; } .meta-text { display:none!important; }
""" """
@ -865,7 +866,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
submit = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') submit = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
submit_result = gr.HTML(elem_id="modelmerger_result") submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
submit.click( submit.click(
fn=run_modelmerger, fn=run_modelmerger,