allow baking in VAE in checkpoint merger tab
do not save config if it's the default for checkpoint merger tab change file naming scheme for checkpoint merger tab allow just saving A without any merging for checkpoint merger tab some stylistic changes for UI in checkpoint merger tab
This commit is contained in:
parent
c7e50425f6
commit
0f5dbfffd0
7 changed files with 102 additions and 59 deletions
|
@ -92,6 +92,7 @@ titles = {
|
||||||
|
|
||||||
"Weighted sum": "Result = A * (1 - M) + B * M",
|
"Weighted sum": "Result = A * (1 - M) + B * M",
|
||||||
"Add difference": "Result = A + (B - C) * M",
|
"Add difference": "Result = A + (B - C) * M",
|
||||||
|
"No interpolation": "Result = A",
|
||||||
|
|
||||||
"Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors",
|
"Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors",
|
||||||
"Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
|
"Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
|
||||||
|
|
|
@ -176,8 +176,6 @@ function modelmerger(){
|
||||||
var id = randomId()
|
var id = randomId()
|
||||||
requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
|
requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
|
||||||
|
|
||||||
gradioApp().getElementById('modelmerger_result').innerHTML = ''
|
|
||||||
|
|
||||||
var res = create_submit_args(arguments)
|
var res = create_submit_args(arguments)
|
||||||
res[0] = id
|
res[0] = id
|
||||||
return res
|
return res
|
||||||
|
|
|
@ -15,7 +15,7 @@ from typing import Callable, List, OrderedDict, Tuple
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from modules import processing, shared, images, devices, sd_models, sd_samplers
|
from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
import modules.gfpgan_model
|
import modules.gfpgan_model
|
||||||
from modules.ui import plaintext_to_html
|
from modules.ui import plaintext_to_html
|
||||||
|
@ -251,7 +251,8 @@ def run_pnginfo(image):
|
||||||
|
|
||||||
def create_config(ckpt_result, config_source, a, b, c):
|
def create_config(ckpt_result, config_source, a, b, c):
|
||||||
def config(x):
|
def config(x):
|
||||||
return sd_models.find_checkpoint_config(x) if x else None
|
res = sd_models.find_checkpoint_config(x) if x else None
|
||||||
|
return res if res != shared.sd_default_config else None
|
||||||
|
|
||||||
if config_source == 0:
|
if config_source == 0:
|
||||||
cfg = config(a) or config(b) or config(c)
|
cfg = config(a) or config(b) or config(c)
|
||||||
|
@ -274,10 +275,12 @@ def create_config(ckpt_result, config_source, a, b, c):
|
||||||
shutil.copyfile(cfg, checkpoint_filename)
|
shutil.copyfile(cfg, checkpoint_filename)
|
||||||
|
|
||||||
|
|
||||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
|
chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
|
||||||
|
|
||||||
|
|
||||||
|
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae):
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
shared.state.job = 'model-merge'
|
shared.state.job = 'model-merge'
|
||||||
shared.state.job_count = 1
|
|
||||||
|
|
||||||
def fail(message):
|
def fail(message):
|
||||||
shared.state.textinfo = message
|
shared.state.textinfo = message
|
||||||
|
@ -293,41 +296,68 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||||
def add_difference(theta0, theta1_2_diff, alpha):
|
def add_difference(theta0, theta1_2_diff, alpha):
|
||||||
return theta0 + (alpha * theta1_2_diff)
|
return theta0 + (alpha * theta1_2_diff)
|
||||||
|
|
||||||
|
def filename_weighed_sum():
|
||||||
|
a = primary_model_info.model_name
|
||||||
|
b = secondary_model_info.model_name
|
||||||
|
Ma = round(1 - multiplier, 2)
|
||||||
|
Mb = round(multiplier, 2)
|
||||||
|
|
||||||
|
return f"{Ma}({a}) + {Mb}({b})"
|
||||||
|
|
||||||
|
def filename_add_differnece():
|
||||||
|
a = primary_model_info.model_name
|
||||||
|
b = secondary_model_info.model_name
|
||||||
|
c = tertiary_model_info.model_name
|
||||||
|
M = round(multiplier, 2)
|
||||||
|
|
||||||
|
return f"{a} + {M}({b} - {c})"
|
||||||
|
|
||||||
|
def filename_nothing():
|
||||||
|
return primary_model_info.model_name
|
||||||
|
|
||||||
|
theta_funcs = {
|
||||||
|
"Weighted sum": (filename_weighed_sum, None, weighted_sum),
|
||||||
|
"Add difference": (filename_add_differnece, get_difference, add_difference),
|
||||||
|
"No interpolation": (filename_nothing, None, None),
|
||||||
|
}
|
||||||
|
filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
|
||||||
|
shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
|
||||||
|
|
||||||
if not primary_model_name:
|
if not primary_model_name:
|
||||||
return fail("Failed: Merging requires a primary model.")
|
return fail("Failed: Merging requires a primary model.")
|
||||||
|
|
||||||
primary_model_info = sd_models.checkpoints_list[primary_model_name]
|
primary_model_info = sd_models.checkpoints_list[primary_model_name]
|
||||||
|
|
||||||
if not secondary_model_name:
|
if theta_func2 and not secondary_model_name:
|
||||||
return fail("Failed: Merging requires a secondary model.")
|
return fail("Failed: Merging requires a secondary model.")
|
||||||
|
|
||||||
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
|
|
||||||
|
|
||||||
theta_funcs = {
|
secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
|
||||||
"Weighted sum": (None, weighted_sum),
|
|
||||||
"Add difference": (get_difference, add_difference),
|
|
||||||
}
|
|
||||||
theta_func1, theta_func2 = theta_funcs[interp_method]
|
|
||||||
|
|
||||||
if theta_func1 and not tertiary_model_name:
|
if theta_func1 and not tertiary_model_name:
|
||||||
return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
|
return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
|
||||||
|
|
||||||
tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
|
tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
|
||||||
|
|
||||||
result_is_inpainting_model = False
|
result_is_inpainting_model = False
|
||||||
|
|
||||||
shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
|
if theta_func2:
|
||||||
print(f"Loading {secondary_model_info.filename}...")
|
shared.state.textinfo = f"Loading B"
|
||||||
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
print(f"Loading {secondary_model_info.filename}...")
|
||||||
|
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
||||||
|
else:
|
||||||
|
theta_1 = None
|
||||||
|
|
||||||
if theta_func1:
|
if theta_func1:
|
||||||
shared.state.job_count += 1
|
shared.state.textinfo = f"Loading C"
|
||||||
|
|
||||||
print(f"Loading {tertiary_model_info.filename}...")
|
print(f"Loading {tertiary_model_info.filename}...")
|
||||||
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
|
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
|
||||||
|
|
||||||
|
shared.state.textinfo = 'Merging B and C'
|
||||||
shared.state.sampling_steps = len(theta_1.keys())
|
shared.state.sampling_steps = len(theta_1.keys())
|
||||||
for key in tqdm.tqdm(theta_1.keys()):
|
for key in tqdm.tqdm(theta_1.keys()):
|
||||||
|
if key in chckpoint_dict_skip_on_merge:
|
||||||
|
continue
|
||||||
|
|
||||||
if 'model' in key:
|
if 'model' in key:
|
||||||
if key in theta_2:
|
if key in theta_2:
|
||||||
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
|
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
|
||||||
|
@ -345,12 +375,10 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||||
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
|
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
|
||||||
|
|
||||||
print("Merging...")
|
print("Merging...")
|
||||||
|
shared.state.textinfo = 'Merging A and B'
|
||||||
chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
|
|
||||||
|
|
||||||
shared.state.sampling_steps = len(theta_0.keys())
|
shared.state.sampling_steps = len(theta_0.keys())
|
||||||
for key in tqdm.tqdm(theta_0.keys()):
|
for key in tqdm.tqdm(theta_0.keys()):
|
||||||
if 'model' in key and key in theta_1:
|
if theta_1 and 'model' in key and key in theta_1:
|
||||||
|
|
||||||
if key in chckpoint_dict_skip_on_merge:
|
if key in chckpoint_dict_skip_on_merge:
|
||||||
continue
|
continue
|
||||||
|
@ -358,7 +386,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||||
a = theta_0[key]
|
a = theta_0[key]
|
||||||
b = theta_1[key]
|
b = theta_1[key]
|
||||||
|
|
||||||
shared.state.textinfo = f'Merging layer {key}'
|
|
||||||
# this enables merging an inpainting model (A) with another one (B);
|
# this enables merging an inpainting model (A) with another one (B);
|
||||||
# where normal model would have 4 channels, for latenst space, inpainting model would
|
# where normal model would have 4 channels, for latenst space, inpainting model would
|
||||||
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
|
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
|
||||||
|
@ -378,34 +405,31 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||||
|
|
||||||
shared.state.sampling_step += 1
|
shared.state.sampling_step += 1
|
||||||
|
|
||||||
# I believe this part should be discarded, but I'll leave it for now until I am sure
|
|
||||||
for key in theta_1.keys():
|
|
||||||
if 'model' in key and key not in theta_0:
|
|
||||||
|
|
||||||
if key in chckpoint_dict_skip_on_merge:
|
|
||||||
continue
|
|
||||||
|
|
||||||
theta_0[key] = theta_1[key]
|
|
||||||
if save_as_half:
|
|
||||||
theta_0[key] = theta_0[key].half()
|
|
||||||
del theta_1
|
del theta_1
|
||||||
|
|
||||||
|
bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
|
||||||
|
if bake_in_vae_filename is not None:
|
||||||
|
print(f"Baking in VAE from {bake_in_vae_filename}")
|
||||||
|
shared.state.textinfo = 'Baking in VAE'
|
||||||
|
vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
|
||||||
|
|
||||||
|
for key in vae_dict.keys():
|
||||||
|
theta_0_key = 'first_stage_model.' + key
|
||||||
|
if theta_0_key in theta_0:
|
||||||
|
theta_0[theta_0_key] = vae_dict[key].half() if save_as_half else vae_dict[key]
|
||||||
|
|
||||||
|
del vae_dict
|
||||||
|
|
||||||
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
||||||
|
|
||||||
filename = \
|
filename = filename_generator() if custom_name == '' else custom_name
|
||||||
primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \
|
filename += ".inpainting" if result_is_inpainting_model else ""
|
||||||
secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \
|
filename += "." + checkpoint_format
|
||||||
interp_method.replace(" ", "_") + \
|
|
||||||
'-merged.' + \
|
|
||||||
("inpainting." if result_is_inpainting_model else "") + \
|
|
||||||
checkpoint_format
|
|
||||||
|
|
||||||
filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
|
|
||||||
|
|
||||||
output_modelname = os.path.join(ckpt_dir, filename)
|
output_modelname = os.path.join(ckpt_dir, filename)
|
||||||
|
|
||||||
shared.state.nextjob()
|
shared.state.nextjob()
|
||||||
shared.state.textinfo = f"Saving to {output_modelname}..."
|
shared.state.textinfo = "Saving"
|
||||||
print(f"Saving to {output_modelname}...")
|
print(f"Saving to {output_modelname}...")
|
||||||
|
|
||||||
_, extension = os.path.splitext(output_modelname)
|
_, extension = os.path.splitext(output_modelname)
|
||||||
|
@ -418,8 +442,8 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||||
|
|
||||||
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
|
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
|
||||||
|
|
||||||
print("Checkpoint saved.")
|
print(f"Checkpoint saved to {output_modelname}.")
|
||||||
shared.state.textinfo = "Checkpoint saved to " + output_modelname
|
shared.state.textinfo = "Checkpoint saved"
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
|
|
||||||
return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
|
return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
|
||||||
|
|
|
@ -120,6 +120,12 @@ def resolve_vae(checkpoint_file):
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
def load_vae_dict(filename, map_location):
|
||||||
|
vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location)
|
||||||
|
vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||||
|
return vae_dict_1
|
||||||
|
|
||||||
|
|
||||||
def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
||||||
global vae_dict, loaded_vae_file
|
global vae_dict, loaded_vae_file
|
||||||
# save_settings = False
|
# save_settings = False
|
||||||
|
@ -137,8 +143,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
||||||
print(f"Loading VAE weights {vae_source}: {vae_file}")
|
print(f"Loading VAE weights {vae_source}: {vae_file}")
|
||||||
store_base_vae(model)
|
store_base_vae(model)
|
||||||
|
|
||||||
vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
|
vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
|
||||||
vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
|
||||||
_load_vae_dict(model, vae_dict_1)
|
_load_vae_dict(model, vae_dict_1)
|
||||||
|
|
||||||
if cache_enabled:
|
if cache_enabled:
|
||||||
|
|
|
@ -20,10 +20,11 @@ from modules.paths import models_path, script_path, sd_path
|
||||||
|
|
||||||
demo = None
|
demo = None
|
||||||
|
|
||||||
|
sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
|
||||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||||
default_sd_model_file = sd_model_file
|
default_sd_model_file = sd_model_file
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
|
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
||||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||||
parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
|
parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
|
||||||
|
|
|
@ -20,7 +20,7 @@ import numpy as np
|
||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||||
|
|
||||||
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
|
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae
|
||||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
|
|
||||||
|
@ -1185,7 +1185,7 @@ def create_ui():
|
||||||
with gr.Column(variant='compact'):
|
with gr.Column(variant='compact'):
|
||||||
gr.HTML(value="<p style='margin-bottom: 2.5em'>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
|
gr.HTML(value="<p style='margin-bottom: 2.5em'>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
|
||||||
|
|
||||||
with FormRow():
|
with FormRow(elem_id="modelmerger_models"):
|
||||||
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
|
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
|
||||||
create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
|
create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
|
||||||
|
|
||||||
|
@ -1197,13 +1197,20 @@ def create_ui():
|
||||||
|
|
||||||
custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
|
custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
|
||||||
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
|
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
|
||||||
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
|
interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
|
||||||
|
|
||||||
with FormRow():
|
with FormRow():
|
||||||
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
|
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
|
||||||
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
|
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
|
||||||
|
|
||||||
config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
|
with FormRow():
|
||||||
|
with gr.Column():
|
||||||
|
config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
with FormRow():
|
||||||
|
bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
|
||||||
|
create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
|
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
|
||||||
|
@ -1757,6 +1764,7 @@ def create_ui():
|
||||||
return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
|
return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result])
|
||||||
modelmerger_merge.click(
|
modelmerger_merge.click(
|
||||||
fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
|
fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
|
||||||
_js='modelmerger',
|
_js='modelmerger',
|
||||||
|
@ -1771,6 +1779,7 @@ def create_ui():
|
||||||
custom_name,
|
custom_name,
|
||||||
checkpoint_format,
|
checkpoint_format,
|
||||||
config_source,
|
config_source,
|
||||||
|
bake_in_vae,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
primary_model_name,
|
primary_model_name,
|
||||||
|
|
15
style.css
15
style.css
|
@ -641,6 +641,16 @@ canvas[key="mask"] {
|
||||||
margin: 0.6em 0em 0.55em 0;
|
margin: 0.6em 0em 0.55em 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#modelmerger_results_container{
|
||||||
|
margin-top: 1em;
|
||||||
|
overflow: visible;
|
||||||
|
}
|
||||||
|
|
||||||
|
#modelmerger_models{
|
||||||
|
gap: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
#quicksettings .gr-button-tool{
|
#quicksettings .gr-button-tool{
|
||||||
margin: 0;
|
margin: 0;
|
||||||
}
|
}
|
||||||
|
@ -737,11 +747,6 @@ footer {
|
||||||
line-height: 2.4em;
|
line-height: 2.4em;
|
||||||
}
|
}
|
||||||
|
|
||||||
#modelmerger_results_container{
|
|
||||||
margin-top: 1em;
|
|
||||||
overflow: visible;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* The following handles localization for right-to-left (RTL) languages like Arabic.
|
/* The following handles localization for right-to-left (RTL) languages like Arabic.
|
||||||
The rtl media type will only be activated by the logic in javascript/localization.js.
|
The rtl media type will only be activated by the logic in javascript/localization.js.
|
||||||
If you change anything above, you need to make sure it is RTL compliant by just running
|
If you change anything above, you need to make sure it is RTL compliant by just running
|
||||||
|
|
Loading…
Reference in a new issue