allow baking in VAE in checkpoint merger tab

do not save config if it's the default for checkpoint merger tab
change file naming scheme for checkpoint merger tab
allow just saving A without any merging for checkpoint merger tab
some stylistic changes for UI in checkpoint merger tab
This commit is contained in:
AUTOMATIC 2023-01-19 10:39:51 +03:00
parent c7e50425f6
commit 0f5dbfffd0
7 changed files with 102 additions and 59 deletions

View file

@ -92,6 +92,7 @@ titles = {
"Weighted sum": "Result = A * (1 - M) + B * M",
"Add difference": "Result = A + (B - C) * M",
"No interpolation": "Result = A",
"Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors",
"Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",

View file

@ -176,8 +176,6 @@ function modelmerger(){
var id = randomId()
requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
gradioApp().getElementById('modelmerger_result').innerHTML = ''
var res = create_submit_args(arguments)
res[0] = id
return res

View file

@ -15,7 +15,7 @@ from typing import Callable, List, OrderedDict, Tuple
from functools import partial
from dataclasses import dataclass
from modules import processing, shared, images, devices, sd_models, sd_samplers
from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae
from modules.shared import opts
import modules.gfpgan_model
from modules.ui import plaintext_to_html
@ -251,7 +251,8 @@ def run_pnginfo(image):
def create_config(ckpt_result, config_source, a, b, c):
def config(x):
return sd_models.find_checkpoint_config(x) if x else None
res = sd_models.find_checkpoint_config(x) if x else None
return res if res != shared.sd_default_config else None
if config_source == 0:
cfg = config(a) or config(b) or config(c)
@ -274,10 +275,12 @@ def create_config(ckpt_result, config_source, a, b, c):
shutil.copyfile(cfg, checkpoint_filename)
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae):
shared.state.begin()
shared.state.job = 'model-merge'
shared.state.job_count = 1
def fail(message):
shared.state.textinfo = message
@ -293,41 +296,68 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
def add_difference(theta0, theta1_2_diff, alpha):
return theta0 + (alpha * theta1_2_diff)
def filename_weighed_sum():
a = primary_model_info.model_name
b = secondary_model_info.model_name
Ma = round(1 - multiplier, 2)
Mb = round(multiplier, 2)
return f"{Ma}({a}) + {Mb}({b})"
def filename_add_differnece():
a = primary_model_info.model_name
b = secondary_model_info.model_name
c = tertiary_model_info.model_name
M = round(multiplier, 2)
return f"{a} + {M}({b} - {c})"
def filename_nothing():
return primary_model_info.model_name
theta_funcs = {
"Weighted sum": (filename_weighed_sum, None, weighted_sum),
"Add difference": (filename_add_differnece, get_difference, add_difference),
"No interpolation": (filename_nothing, None, None),
}
filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
if not primary_model_name:
return fail("Failed: Merging requires a primary model.")
primary_model_info = sd_models.checkpoints_list[primary_model_name]
if not secondary_model_name:
if theta_func2 and not secondary_model_name:
return fail("Failed: Merging requires a secondary model.")
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
theta_funcs = {
"Weighted sum": (None, weighted_sum),
"Add difference": (get_difference, add_difference),
}
theta_func1, theta_func2 = theta_funcs[interp_method]
secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
if theta_func1 and not tertiary_model_name:
return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
result_is_inpainting_model = False
shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
print(f"Loading {secondary_model_info.filename}...")
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
if theta_func2:
shared.state.textinfo = f"Loading B"
print(f"Loading {secondary_model_info.filename}...")
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
else:
theta_1 = None
if theta_func1:
shared.state.job_count += 1
shared.state.textinfo = f"Loading C"
print(f"Loading {tertiary_model_info.filename}...")
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
shared.state.textinfo = 'Merging B and C'
shared.state.sampling_steps = len(theta_1.keys())
for key in tqdm.tqdm(theta_1.keys()):
if key in chckpoint_dict_skip_on_merge:
continue
if 'model' in key:
if key in theta_2:
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
@ -345,12 +375,10 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
print("Merging...")
chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
shared.state.textinfo = 'Merging A and B'
shared.state.sampling_steps = len(theta_0.keys())
for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1:
if theta_1 and 'model' in key and key in theta_1:
if key in chckpoint_dict_skip_on_merge:
continue
@ -358,7 +386,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
a = theta_0[key]
b = theta_1[key]
shared.state.textinfo = f'Merging layer {key}'
# this enables merging an inpainting model (A) with another one (B);
# where normal model would have 4 channels, for latenst space, inpainting model would
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
@ -378,34 +405,31 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
shared.state.sampling_step += 1
# I believe this part should be discarded, but I'll leave it for now until I am sure
for key in theta_1.keys():
if 'model' in key and key not in theta_0:
if key in chckpoint_dict_skip_on_merge:
continue
theta_0[key] = theta_1[key]
if save_as_half:
theta_0[key] = theta_0[key].half()
del theta_1
bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
if bake_in_vae_filename is not None:
print(f"Baking in VAE from {bake_in_vae_filename}")
shared.state.textinfo = 'Baking in VAE'
vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
for key in vae_dict.keys():
theta_0_key = 'first_stage_model.' + key
if theta_0_key in theta_0:
theta_0[theta_0_key] = vae_dict[key].half() if save_as_half else vae_dict[key]
del vae_dict
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
filename = \
primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \
secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \
interp_method.replace(" ", "_") + \
'-merged.' + \
("inpainting." if result_is_inpainting_model else "") + \
checkpoint_format
filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
filename = filename_generator() if custom_name == '' else custom_name
filename += ".inpainting" if result_is_inpainting_model else ""
filename += "." + checkpoint_format
output_modelname = os.path.join(ckpt_dir, filename)
shared.state.nextjob()
shared.state.textinfo = f"Saving to {output_modelname}..."
shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...")
_, extension = os.path.splitext(output_modelname)
@ -418,8 +442,8 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
print("Checkpoint saved.")
shared.state.textinfo = "Checkpoint saved to " + output_modelname
print(f"Checkpoint saved to {output_modelname}.")
shared.state.textinfo = "Checkpoint saved"
shared.state.end()
return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]

View file

@ -120,6 +120,12 @@ def resolve_vae(checkpoint_file):
return None, None
def load_vae_dict(filename, map_location):
vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location)
vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
return vae_dict_1
def load_vae(model, vae_file=None, vae_source="from unknown source"):
global vae_dict, loaded_vae_file
# save_settings = False
@ -137,8 +143,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
print(f"Loading VAE weights {vae_source}: {vae_file}")
store_base_vae(model)
vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
_load_vae_dict(model, vae_dict_1)
if cache_enabled:

View file

@ -20,10 +20,11 @@ from modules.paths import models_path, script_path, sd_path
demo = None
sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")

View file

@ -20,7 +20,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
@ -1185,7 +1185,7 @@ def create_ui():
with gr.Column(variant='compact'):
gr.HTML(value="<p style='margin-bottom: 2.5em'>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
with FormRow():
with FormRow(elem_id="modelmerger_models"):
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
@ -1197,13 +1197,20 @@ def create_ui():
custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
with FormRow():
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
with FormRow():
with gr.Column():
config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
with gr.Column():
with FormRow():
bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
with gr.Row():
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
@ -1757,6 +1764,7 @@ def create_ui():
return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
return results
modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result])
modelmerger_merge.click(
fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
_js='modelmerger',
@ -1771,6 +1779,7 @@ def create_ui():
custom_name,
checkpoint_format,
config_source,
bake_in_vae,
],
outputs=[
primary_model_name,

View file

@ -641,6 +641,16 @@ canvas[key="mask"] {
margin: 0.6em 0em 0.55em 0;
}
#modelmerger_results_container{
margin-top: 1em;
overflow: visible;
}
#modelmerger_models{
gap: 0;
}
#quicksettings .gr-button-tool{
margin: 0;
}
@ -737,11 +747,6 @@ footer {
line-height: 2.4em;
}
#modelmerger_results_container{
margin-top: 1em;
overflow: visible;
}
/* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running