allow baking in VAE in checkpoint merger tab

do not save config if it's the default for checkpoint merger tab
change file naming scheme for checkpoint merger tab
allow just saving A without any merging for checkpoint merger tab
some stylistic changes for UI in checkpoint merger tab
This commit is contained in:
AUTOMATIC 2023-01-19 10:39:51 +03:00
parent c7e50425f6
commit 0f5dbfffd0
7 changed files with 102 additions and 59 deletions

View file

@ -92,6 +92,7 @@ titles = {
"Weighted sum": "Result = A * (1 - M) + B * M", "Weighted sum": "Result = A * (1 - M) + B * M",
"Add difference": "Result = A + (B - C) * M", "Add difference": "Result = A + (B - C) * M",
"No interpolation": "Result = A",
"Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors", "Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors",
"Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.", "Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",

View file

@ -176,8 +176,6 @@ function modelmerger(){
var id = randomId() var id = randomId()
requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){}) requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
gradioApp().getElementById('modelmerger_result').innerHTML = ''
var res = create_submit_args(arguments) var res = create_submit_args(arguments)
res[0] = id res[0] = id
return res return res

View file

@ -15,7 +15,7 @@ from typing import Callable, List, OrderedDict, Tuple
from functools import partial from functools import partial
from dataclasses import dataclass from dataclasses import dataclass
from modules import processing, shared, images, devices, sd_models, sd_samplers from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae
from modules.shared import opts from modules.shared import opts
import modules.gfpgan_model import modules.gfpgan_model
from modules.ui import plaintext_to_html from modules.ui import plaintext_to_html
@ -251,7 +251,8 @@ def run_pnginfo(image):
def create_config(ckpt_result, config_source, a, b, c): def create_config(ckpt_result, config_source, a, b, c):
def config(x): def config(x):
return sd_models.find_checkpoint_config(x) if x else None res = sd_models.find_checkpoint_config(x) if x else None
return res if res != shared.sd_default_config else None
if config_source == 0: if config_source == 0:
cfg = config(a) or config(b) or config(c) cfg = config(a) or config(b) or config(c)
@ -274,10 +275,12 @@ def create_config(ckpt_result, config_source, a, b, c):
shutil.copyfile(cfg, checkpoint_filename) shutil.copyfile(cfg, checkpoint_filename)
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source): chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae):
shared.state.begin() shared.state.begin()
shared.state.job = 'model-merge' shared.state.job = 'model-merge'
shared.state.job_count = 1
def fail(message): def fail(message):
shared.state.textinfo = message shared.state.textinfo = message
@ -293,41 +296,68 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
def add_difference(theta0, theta1_2_diff, alpha): def add_difference(theta0, theta1_2_diff, alpha):
return theta0 + (alpha * theta1_2_diff) return theta0 + (alpha * theta1_2_diff)
def filename_weighed_sum():
    """Return the default output name used for the "Weighted sum" method.

    Reads ``primary_model_info``, ``secondary_model_info`` and ``multiplier``
    as free variables from the surrounding scope. The name has the form
    ``"<1-M>(<A>) + <M>(<B>)"`` with both weights rounded to two decimals.
    """
    name_a = primary_model_info.model_name
    name_b = secondary_model_info.model_name
    # Weights mirror the formula Result = A * (1 - M) + B * M.
    weight_a, weight_b = round(1 - multiplier, 2), round(multiplier, 2)
    return f"{weight_a}({name_a}) + {weight_b}({name_b})"
def filename_add_differnece():
    """Return the default output name used for the "Add difference" method.

    Reads ``primary_model_info``, ``secondary_model_info``,
    ``tertiary_model_info`` and ``multiplier`` as free variables from the
    surrounding scope. The name has the form ``"<A> + <M>(<B> - <C>)"`` with
    the multiplier rounded to two decimals.
    """
    name_a = primary_model_info.model_name
    name_b = secondary_model_info.model_name
    name_c = tertiary_model_info.model_name
    # Mirrors the formula Result = A + (B - C) * M.
    weight = round(multiplier, 2)
    return f"{name_a} + {weight}({name_b} - {name_c})"
def filename_nothing():
    """Return the default output name used for the "No interpolation" method.

    The result is simply the primary model's name, read from
    ``primary_model_info`` in the surrounding scope.
    """
    primary_name = primary_model_info.model_name
    return primary_name
theta_funcs = {
"Weighted sum": (filename_weighed_sum, None, weighted_sum),
"Add difference": (filename_add_differnece, get_difference, add_difference),
"No interpolation": (filename_nothing, None, None),
}
filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
if not primary_model_name: if not primary_model_name:
return fail("Failed: Merging requires a primary model.") return fail("Failed: Merging requires a primary model.")
primary_model_info = sd_models.checkpoints_list[primary_model_name] primary_model_info = sd_models.checkpoints_list[primary_model_name]
if not secondary_model_name: if theta_func2 and not secondary_model_name:
return fail("Failed: Merging requires a secondary model.") return fail("Failed: Merging requires a secondary model.")
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
theta_funcs = { secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
"Weighted sum": (None, weighted_sum),
"Add difference": (get_difference, add_difference),
}
theta_func1, theta_func2 = theta_funcs[interp_method]
if theta_func1 and not tertiary_model_name: if theta_func1 and not tertiary_model_name:
return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
result_is_inpainting_model = False result_is_inpainting_model = False
shared.state.textinfo = f"Loading {secondary_model_info.filename}..." if theta_func2:
print(f"Loading {secondary_model_info.filename}...") shared.state.textinfo = f"Loading B"
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') print(f"Loading {secondary_model_info.filename}...")
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
else:
theta_1 = None
if theta_func1: if theta_func1:
shared.state.job_count += 1 shared.state.textinfo = f"Loading C"
print(f"Loading {tertiary_model_info.filename}...") print(f"Loading {tertiary_model_info.filename}...")
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
shared.state.textinfo = 'Merging B and C'
shared.state.sampling_steps = len(theta_1.keys()) shared.state.sampling_steps = len(theta_1.keys())
for key in tqdm.tqdm(theta_1.keys()): for key in tqdm.tqdm(theta_1.keys()):
if key in chckpoint_dict_skip_on_merge:
continue
if 'model' in key: if 'model' in key:
if key in theta_2: if key in theta_2:
t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
@ -345,12 +375,10 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
print("Merging...") print("Merging...")
shared.state.textinfo = 'Merging A and B'
chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
shared.state.sampling_steps = len(theta_0.keys()) shared.state.sampling_steps = len(theta_0.keys())
for key in tqdm.tqdm(theta_0.keys()): for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1: if theta_1 and 'model' in key and key in theta_1:
if key in chckpoint_dict_skip_on_merge: if key in chckpoint_dict_skip_on_merge:
continue continue
@ -358,7 +386,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
a = theta_0[key] a = theta_0[key]
b = theta_1[key] b = theta_1[key]
shared.state.textinfo = f'Merging layer {key}'
# this enables merging an inpainting model (A) with another one (B); # this enables merging an inpainting model (A) with another one (B);
# where normal model would have 4 channels, for latent space, inpainting model would # where normal model would have 4 channels, for latent space, inpainting model would
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
@ -378,34 +405,31 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
shared.state.sampling_step += 1 shared.state.sampling_step += 1
# I believe this part should be discarded, but I'll leave it for now until I am sure
for key in theta_1.keys():
if 'model' in key and key not in theta_0:
if key in chckpoint_dict_skip_on_merge:
continue
theta_0[key] = theta_1[key]
if save_as_half:
theta_0[key] = theta_0[key].half()
del theta_1 del theta_1
bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
if bake_in_vae_filename is not None:
print(f"Baking in VAE from {bake_in_vae_filename}")
shared.state.textinfo = 'Baking in VAE'
vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
for key in vae_dict.keys():
theta_0_key = 'first_stage_model.' + key
if theta_0_key in theta_0:
theta_0[theta_0_key] = vae_dict[key].half() if save_as_half else vae_dict[key]
del vae_dict
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
filename = \ filename = filename_generator() if custom_name == '' else custom_name
primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \ filename += ".inpainting" if result_is_inpainting_model else ""
secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \ filename += "." + checkpoint_format
interp_method.replace(" ", "_") + \
'-merged.' + \
("inpainting." if result_is_inpainting_model else "") + \
checkpoint_format
filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
output_modelname = os.path.join(ckpt_dir, filename) output_modelname = os.path.join(ckpt_dir, filename)
shared.state.nextjob() shared.state.nextjob()
shared.state.textinfo = f"Saving to {output_modelname}..." shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...") print(f"Saving to {output_modelname}...")
_, extension = os.path.splitext(output_modelname) _, extension = os.path.splitext(output_modelname)
@ -418,8 +442,8 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
print("Checkpoint saved.") print(f"Checkpoint saved to {output_modelname}.")
shared.state.textinfo = "Checkpoint saved to " + output_modelname shared.state.textinfo = "Checkpoint saved"
shared.state.end() shared.state.end()
return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname] return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]

View file

@ -120,6 +120,12 @@ def resolve_vae(checkpoint_file):
return None, None return None, None
def load_vae_dict(filename, map_location):
    """Read a VAE checkpoint file and return its filtered state dict.

    The file is read with ``sd_models.read_state_dict`` using the given
    ``map_location``. Entries whose key starts with "loss" or appears in
    ``vae_ignore_keys`` are left out of the returned dict.
    """
    loaded = sd_models.read_state_dict(filename, map_location=map_location)
    filtered = {}
    for name, value in loaded.items():
        # Skip loss-related entries and explicitly ignored keys.
        if name[0:4] == "loss" or name in vae_ignore_keys:
            continue
        filtered[name] = value
    return filtered
def load_vae(model, vae_file=None, vae_source="from unknown source"): def load_vae(model, vae_file=None, vae_source="from unknown source"):
global vae_dict, loaded_vae_file global vae_dict, loaded_vae_file
# save_settings = False # save_settings = False
@ -137,8 +143,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
print(f"Loading VAE weights {vae_source}: {vae_file}") print(f"Loading VAE weights {vae_source}: {vae_file}")
store_base_vae(model) store_base_vae(model)
vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location) vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
_load_vae_dict(model, vae_dict_1) _load_vae_dict(model, vae_dict_1)
if cache_enabled: if cache_enabled:

View file

@ -20,10 +20,11 @@ from modules.paths import models_path, script_path, sd_path
demo = None demo = None
sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
sd_model_file = os.path.join(script_path, 'model.ckpt') sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file default_sd_model_file = sd_model_file
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files") parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")

View file

@ -20,7 +20,7 @@ import numpy as np
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path from modules.paths import script_path
@ -1185,7 +1185,7 @@ def create_ui():
with gr.Column(variant='compact'): with gr.Column(variant='compact'):
gr.HTML(value="<p style='margin-bottom: 2.5em'>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>") gr.HTML(value="<p style='margin-bottom: 2.5em'>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
with FormRow(): with FormRow(elem_id="modelmerger_models"):
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
@ -1197,13 +1197,20 @@ def create_ui():
custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
with FormRow(): with FormRow():
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method") with FormRow():
with gr.Column():
config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
with gr.Column():
with FormRow():
bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
with gr.Row(): with gr.Row():
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
@ -1757,6 +1764,7 @@ def create_ui():
return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"] return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
return results return results
modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result])
modelmerger_merge.click( modelmerger_merge.click(
fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]), fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
_js='modelmerger', _js='modelmerger',
@ -1771,6 +1779,7 @@ def create_ui():
custom_name, custom_name,
checkpoint_format, checkpoint_format,
config_source, config_source,
bake_in_vae,
], ],
outputs=[ outputs=[
primary_model_name, primary_model_name,

View file

@ -641,6 +641,16 @@ canvas[key="mask"] {
margin: 0.6em 0em 0.55em 0; margin: 0.6em 0em 0.55em 0;
} }
#modelmerger_results_container{
margin-top: 1em;
overflow: visible;
}
#modelmerger_models{
gap: 0;
}
#quicksettings .gr-button-tool{ #quicksettings .gr-button-tool{
margin: 0; margin: 0;
} }
@ -737,11 +747,6 @@ footer {
line-height: 2.4em; line-height: 2.4em;
} }
#modelmerger_results_container{
margin-top: 1em;
overflow: visible;
}
/* The following handles localization for right-to-left (RTL) languages like Arabic. /* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js. The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running If you change anything above, you need to make sure it is RTL compliant by just running