add progress bar to modelmerger
This commit is contained in:
parent
7cfc645030
commit
c7e50425f6
5 changed files with 40 additions and 9 deletions
|
@ -172,6 +172,17 @@ function submit_img2img(){
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function modelmerger(){
|
||||||
|
var id = randomId()
|
||||||
|
requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
|
||||||
|
|
||||||
|
gradioApp().getElementById('modelmerger_result').innerHTML = ''
|
||||||
|
|
||||||
|
var res = create_submit_args(arguments)
|
||||||
|
res[0] = id
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
function ask_for_style_name(_, prompt_text, negative_prompt_text) {
|
function ask_for_style_name(_, prompt_text, negative_prompt_text) {
|
||||||
name_ = prompt('Style name:')
|
name_ = prompt('Style name:')
|
||||||
|
|
|
@ -274,14 +274,15 @@ def create_config(ckpt_result, config_source, a, b, c):
|
||||||
shutil.copyfile(cfg, checkpoint_filename)
|
shutil.copyfile(cfg, checkpoint_filename)
|
||||||
|
|
||||||
|
|
||||||
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
|
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
shared.state.job = 'model-merge'
|
shared.state.job = 'model-merge'
|
||||||
|
shared.state.job_count = 1
|
||||||
|
|
||||||
def fail(message):
|
def fail(message):
|
||||||
shared.state.textinfo = message
|
shared.state.textinfo = message
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return [message, *[gr.update() for _ in range(4)]]
|
return [*[gr.update() for _ in range(4)], message]
|
||||||
|
|
||||||
def weighted_sum(theta0, theta1, alpha):
|
def weighted_sum(theta0, theta1, alpha):
|
||||||
return ((1 - alpha) * theta0) + (alpha * theta1)
|
return ((1 - alpha) * theta0) + (alpha * theta1)
|
||||||
|
@ -320,9 +321,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||||
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
||||||
|
|
||||||
if theta_func1:
|
if theta_func1:
|
||||||
|
shared.state.job_count += 1
|
||||||
|
|
||||||
print(f"Loading {tertiary_model_info.filename}...")
|
print(f"Loading {tertiary_model_info.filename}...")
|
||||||
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
|
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
|
||||||
|
|
||||||
|
shared.state.sampling_steps = len(theta_1.keys())
|
||||||
for key in tqdm.tqdm(theta_1.keys()):
|
for key in tqdm.tqdm(theta_1.keys()):
|
||||||
if 'model' in key:
|
if 'model' in key:
|
||||||
if key in theta_2:
|
if key in theta_2:
|
||||||
|
@ -330,8 +334,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||||
theta_1[key] = theta_func1(theta_1[key], t2)
|
theta_1[key] = theta_func1(theta_1[key], t2)
|
||||||
else:
|
else:
|
||||||
theta_1[key] = torch.zeros_like(theta_1[key])
|
theta_1[key] = torch.zeros_like(theta_1[key])
|
||||||
|
|
||||||
|
shared.state.sampling_step += 1
|
||||||
del theta_2
|
del theta_2
|
||||||
|
|
||||||
|
shared.state.nextjob()
|
||||||
|
|
||||||
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
|
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
|
||||||
print(f"Loading {primary_model_info.filename}...")
|
print(f"Loading {primary_model_info.filename}...")
|
||||||
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
|
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
|
||||||
|
@ -340,6 +348,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||||
|
|
||||||
chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
|
chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
|
||||||
|
|
||||||
|
shared.state.sampling_steps = len(theta_0.keys())
|
||||||
for key in tqdm.tqdm(theta_0.keys()):
|
for key in tqdm.tqdm(theta_0.keys()):
|
||||||
if 'model' in key and key in theta_1:
|
if 'model' in key and key in theta_1:
|
||||||
|
|
||||||
|
@ -367,6 +376,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||||
if save_as_half:
|
if save_as_half:
|
||||||
theta_0[key] = theta_0[key].half()
|
theta_0[key] = theta_0[key].half()
|
||||||
|
|
||||||
|
shared.state.sampling_step += 1
|
||||||
|
|
||||||
# I believe this part should be discarded, but I'll leave it for now until I am sure
|
# I believe this part should be discarded, but I'll leave it for now until I am sure
|
||||||
for key in theta_1.keys():
|
for key in theta_1.keys():
|
||||||
if 'model' in key and key not in theta_0:
|
if 'model' in key and key not in theta_0:
|
||||||
|
@ -393,6 +404,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||||
|
|
||||||
output_modelname = os.path.join(ckpt_dir, filename)
|
output_modelname = os.path.join(ckpt_dir, filename)
|
||||||
|
|
||||||
|
shared.state.nextjob()
|
||||||
shared.state.textinfo = f"Saving to {output_modelname}..."
|
shared.state.textinfo = f"Saving to {output_modelname}..."
|
||||||
print(f"Saving to {output_modelname}...")
|
print(f"Saving to {output_modelname}...")
|
||||||
|
|
||||||
|
@ -410,4 +422,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||||
shared.state.textinfo = "Checkpoint saved to " + output_modelname
|
shared.state.textinfo = "Checkpoint saved to " + output_modelname
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
|
|
||||||
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
|
return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
|
||||||
|
|
|
@ -72,7 +72,7 @@ def progressapi(req: ProgressRequest):
|
||||||
|
|
||||||
if job_count > 0:
|
if job_count > 0:
|
||||||
progress += job_no / job_count
|
progress += job_no / job_count
|
||||||
if sampling_steps > 0:
|
if sampling_steps > 0 and job_count > 0:
|
||||||
progress += 1 / job_count * sampling_step / sampling_steps
|
progress += 1 / job_count * sampling_step / sampling_steps
|
||||||
|
|
||||||
progress = min(progress, 1)
|
progress = min(progress, 1)
|
||||||
|
|
|
@ -1208,8 +1208,9 @@ def create_ui():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
|
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
|
||||||
|
|
||||||
with gr.Column(variant='panel'):
|
with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
|
||||||
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
|
with gr.Group(elem_id="modelmerger_results_panel"):
|
||||||
|
modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as train_interface:
|
with gr.Blocks(analytics_enabled=False) as train_interface:
|
||||||
with gr.Row().style(equal_height=False):
|
with gr.Row().style(equal_height=False):
|
||||||
|
@ -1753,12 +1754,14 @@ def create_ui():
|
||||||
print("Error loading/saving model file:", file=sys.stderr)
|
print("Error loading/saving model file:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
modules.sd_models.list_models() # to remove the potentially missing models from the list
|
modules.sd_models.list_models() # to remove the potentially missing models from the list
|
||||||
return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)]
|
return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
|
||||||
return results
|
return results
|
||||||
|
|
||||||
modelmerger_merge.click(
|
modelmerger_merge.click(
|
||||||
fn=modelmerger,
|
fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
|
||||||
|
_js='modelmerger',
|
||||||
inputs=[
|
inputs=[
|
||||||
|
dummy_component,
|
||||||
primary_model_name,
|
primary_model_name,
|
||||||
secondary_model_name,
|
secondary_model_name,
|
||||||
tertiary_model_name,
|
tertiary_model_name,
|
||||||
|
@ -1770,11 +1773,11 @@ def create_ui():
|
||||||
config_source,
|
config_source,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
submit_result,
|
|
||||||
primary_model_name,
|
primary_model_name,
|
||||||
secondary_model_name,
|
secondary_model_name,
|
||||||
tertiary_model_name,
|
tertiary_model_name,
|
||||||
component_dict['sd_model_checkpoint'],
|
component_dict['sd_model_checkpoint'],
|
||||||
|
modelmerger_result,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -737,6 +737,11 @@ footer {
|
||||||
line-height: 2.4em;
|
line-height: 2.4em;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#modelmerger_results_container{
|
||||||
|
margin-top: 1em;
|
||||||
|
overflow: visible;
|
||||||
|
}
|
||||||
|
|
||||||
/* The following handles localization for right-to-left (RTL) languages like Arabic.
|
/* The following handles localization for right-to-left (RTL) languages like Arabic.
|
||||||
The rtl media type will only be activated by the logic in javascript/localization.js.
|
The rtl media type will only be activated by the logic in javascript/localization.js.
|
||||||
If you change anything above, you need to make sure it is RTL compliant by just running
|
If you change anything above, you need to make sure it is RTL compliant by just running
|
||||||
|
|
Loading…
Reference in a new issue