add progress bar to modelmerger

This commit is contained in:
AUTOMATIC 2023-01-19 09:25:37 +03:00
parent 7cfc645030
commit c7e50425f6
5 changed files with 40 additions and 9 deletions

View file

@ -172,6 +172,17 @@ function submit_img2img(){
return res return res
} }
function modelmerger(){
var id = randomId()
requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
gradioApp().getElementById('modelmerger_result').innerHTML = ''
var res = create_submit_args(arguments)
res[0] = id
return res
}
function ask_for_style_name(_, prompt_text, negative_prompt_text) { function ask_for_style_name(_, prompt_text, negative_prompt_text) {
name_ = prompt('Style name:') name_ = prompt('Style name:')

View file

@ -274,14 +274,15 @@ def create_config(ckpt_result, config_source, a, b, c):
shutil.copyfile(cfg, checkpoint_filename) shutil.copyfile(cfg, checkpoint_filename)
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source): def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
shared.state.begin() shared.state.begin()
shared.state.job = 'model-merge' shared.state.job = 'model-merge'
shared.state.job_count = 1
def fail(message): def fail(message):
shared.state.textinfo = message shared.state.textinfo = message
shared.state.end() shared.state.end()
return [message, *[gr.update() for _ in range(4)]] return [*[gr.update() for _ in range(4)], message]
def weighted_sum(theta0, theta1, alpha): def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1) return ((1 - alpha) * theta0) + (alpha * theta1)
@ -320,9 +321,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
if theta_func1: if theta_func1:
shared.state.job_count += 1
print(f"Loading {tertiary_model_info.filename}...") print(f"Loading {tertiary_model_info.filename}...")
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
shared.state.sampling_steps = len(theta_1.keys())
for key in tqdm.tqdm(theta_1.keys()): for key in tqdm.tqdm(theta_1.keys()):
if 'model' in key: if 'model' in key:
if key in theta_2: if key in theta_2:
@ -330,8 +334,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_1[key] = theta_func1(theta_1[key], t2) theta_1[key] = theta_func1(theta_1[key], t2)
else: else:
theta_1[key] = torch.zeros_like(theta_1[key]) theta_1[key] = torch.zeros_like(theta_1[key])
shared.state.sampling_step += 1
del theta_2 del theta_2
shared.state.nextjob()
shared.state.textinfo = f"Loading {primary_model_info.filename}..." shared.state.textinfo = f"Loading {primary_model_info.filename}..."
print(f"Loading {primary_model_info.filename}...") print(f"Loading {primary_model_info.filename}...")
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
@ -340,6 +348,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
shared.state.sampling_steps = len(theta_0.keys())
for key in tqdm.tqdm(theta_0.keys()): for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1: if 'model' in key and key in theta_1:
@ -367,6 +376,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
if save_as_half: if save_as_half:
theta_0[key] = theta_0[key].half() theta_0[key] = theta_0[key].half()
shared.state.sampling_step += 1
# I believe this part should be discarded, but I'll leave it for now until I am sure # I believe this part should be discarded, but I'll leave it for now until I am sure
for key in theta_1.keys(): for key in theta_1.keys():
if 'model' in key and key not in theta_0: if 'model' in key and key not in theta_0:
@ -393,6 +404,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
output_modelname = os.path.join(ckpt_dir, filename) output_modelname = os.path.join(ckpt_dir, filename)
shared.state.nextjob()
shared.state.textinfo = f"Saving to {output_modelname}..." shared.state.textinfo = f"Saving to {output_modelname}..."
print(f"Saving to {output_modelname}...") print(f"Saving to {output_modelname}...")
@ -410,4 +422,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
shared.state.textinfo = "Checkpoint saved to " + output_modelname shared.state.textinfo = "Checkpoint saved to " + output_modelname
shared.state.end() shared.state.end()
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]

View file

@ -72,7 +72,7 @@ def progressapi(req: ProgressRequest):
if job_count > 0: if job_count > 0:
progress += job_no / job_count progress += job_no / job_count
if sampling_steps > 0: if sampling_steps > 0 and job_count > 0:
progress += 1 / job_count * sampling_step / sampling_steps progress += 1 / job_count * sampling_step / sampling_steps
progress = min(progress, 1) progress = min(progress, 1)

View file

@ -1208,8 +1208,9 @@ def create_ui():
with gr.Row(): with gr.Row():
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
with gr.Column(variant='panel'): with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) with gr.Group(elem_id="modelmerger_results_panel"):
modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
with gr.Blocks(analytics_enabled=False) as train_interface: with gr.Blocks(analytics_enabled=False) as train_interface:
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
@ -1753,12 +1754,14 @@ def create_ui():
print("Error loading/saving model file:", file=sys.stderr) print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
modules.sd_models.list_models() # to remove the potentially missing models from the list modules.sd_models.list_models() # to remove the potentially missing models from the list
return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)] return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
return results return results
modelmerger_merge.click( modelmerger_merge.click(
fn=modelmerger, fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
_js='modelmerger',
inputs=[ inputs=[
dummy_component,
primary_model_name, primary_model_name,
secondary_model_name, secondary_model_name,
tertiary_model_name, tertiary_model_name,
@ -1770,11 +1773,11 @@ def create_ui():
config_source, config_source,
], ],
outputs=[ outputs=[
submit_result,
primary_model_name, primary_model_name,
secondary_model_name, secondary_model_name,
tertiary_model_name, tertiary_model_name,
component_dict['sd_model_checkpoint'], component_dict['sd_model_checkpoint'],
modelmerger_result,
] ]
) )

View file

@ -737,6 +737,11 @@ footer {
line-height: 2.4em; line-height: 2.4em;
} }
#modelmerger_results_container{
margin-top: 1em;
overflow: visible;
}
/* The following handles localization for right-to-left (RTL) languages like Arabic. /* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js. The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running If you change anything above, you need to make sure it is RTL compliant by just running