Merge branch 'AUTOMATIC1111:master' into fix/alternating-words-emphasis
This commit is contained in:
commit
035f2af050
6 changed files with 46 additions and 7 deletions
|
@ -151,6 +151,7 @@ function showGalleryImage() {
|
||||||
e.addEventListener('mousedown', function (evt) {
|
e.addEventListener('mousedown', function (evt) {
|
||||||
if(!opts.js_modal_lightbox || evt.button != 0) return;
|
if(!opts.js_modal_lightbox || evt.button != 0) return;
|
||||||
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
|
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
|
||||||
|
evt.preventDefault()
|
||||||
showModal(evt)
|
showModal(evt)
|
||||||
}, true);
|
}, true);
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
import shutil
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -248,7 +249,32 @@ def run_pnginfo(image):
|
||||||
return '', geninfo, info
|
return '', geninfo, info
|
||||||
|
|
||||||
|
|
||||||
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
|
def create_config(ckpt_result, config_source, a, b, c):
|
||||||
|
def config(x):
|
||||||
|
return sd_models.find_checkpoint_config(x) if x else None
|
||||||
|
|
||||||
|
if config_source == 0:
|
||||||
|
cfg = config(a) or config(b) or config(c)
|
||||||
|
elif config_source == 1:
|
||||||
|
cfg = config(b)
|
||||||
|
elif config_source == 2:
|
||||||
|
cfg = config(c)
|
||||||
|
else:
|
||||||
|
cfg = None
|
||||||
|
|
||||||
|
if cfg is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
filename, _ = os.path.splitext(ckpt_result)
|
||||||
|
checkpoint_filename = filename + ".yaml"
|
||||||
|
|
||||||
|
print("Copying config:")
|
||||||
|
print(" from:", cfg)
|
||||||
|
print(" to:", checkpoint_filename)
|
||||||
|
shutil.copyfile(cfg, checkpoint_filename)
|
||||||
|
|
||||||
|
|
||||||
|
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
shared.state.job = 'model-merge'
|
shared.state.job = 'model-merge'
|
||||||
|
|
||||||
|
@ -356,6 +382,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||||
|
|
||||||
sd_models.list_models()
|
sd_models.list_models()
|
||||||
|
|
||||||
|
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
|
||||||
|
|
||||||
print("Checkpoint saved.")
|
print("Checkpoint saved.")
|
||||||
shared.state.textinfo = "Checkpoint saved to " + output_modelname
|
shared.state.textinfo = "Checkpoint saved to " + output_modelname
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
|
|
|
@ -152,7 +152,7 @@ def basedir():
|
||||||
|
|
||||||
scripts_data = []
|
scripts_data = []
|
||||||
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
|
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
|
||||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"])
|
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
||||||
|
|
||||||
|
|
||||||
def list_scripts(scriptdirname, extension):
|
def list_scripts(scriptdirname, extension):
|
||||||
|
@ -206,7 +206,7 @@ def load_scripts():
|
||||||
|
|
||||||
for key, script_class in module.__dict__.items():
|
for key, script_class in module.__dict__.items():
|
||||||
if type(script_class) == type and issubclass(script_class, Script):
|
if type(script_class) == type and issubclass(script_class, Script):
|
||||||
scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir))
|
scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
|
print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
|
||||||
|
@ -241,7 +241,7 @@ class ScriptRunner:
|
||||||
self.alwayson_scripts.clear()
|
self.alwayson_scripts.clear()
|
||||||
self.selectable_scripts.clear()
|
self.selectable_scripts.clear()
|
||||||
|
|
||||||
for script_class, path, basedir in scripts_data:
|
for script_class, path, basedir, script_module in scripts_data:
|
||||||
script = script_class()
|
script = script_class()
|
||||||
script.filename = path
|
script.filename = path
|
||||||
script.is_txt2img = not is_img2img
|
script.is_txt2img = not is_img2img
|
||||||
|
|
|
@ -333,10 +333,14 @@ def load_model(checkpoint_info=None):
|
||||||
|
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
|
|
||||||
|
sd_model = None
|
||||||
try:
|
try:
|
||||||
with sd_disable_initialization.DisableInitialization():
|
with sd_disable_initialization.DisableInitialization():
|
||||||
sd_model = instantiate_from_config(sd_config.model)
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if sd_model is None:
|
||||||
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
|
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
|
||||||
sd_model = instantiate_from_config(sd_config.model)
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ import tqdm
|
||||||
import html
|
import html
|
||||||
import datetime
|
import datetime
|
||||||
import csv
|
import csv
|
||||||
|
import safetensors.torch
|
||||||
|
|
||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
|
|
||||||
|
@ -150,6 +151,8 @@ class EmbeddingDatabase:
|
||||||
name = data.get('name', name)
|
name = data.get('name', name)
|
||||||
elif ext in ['.BIN', '.PT']:
|
elif ext in ['.BIN', '.PT']:
|
||||||
data = torch.load(path, map_location="cpu")
|
data = torch.load(path, map_location="cpu")
|
||||||
|
elif ext in ['.SAFETENSORS']:
|
||||||
|
data = safetensors.torch.load_file(path, device="cpu")
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -1129,7 +1129,7 @@ def create_ui():
|
||||||
with gr.Column(variant='panel'):
|
with gr.Column(variant='panel'):
|
||||||
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
|
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
|
||||||
|
|
||||||
with gr.Row():
|
with FormRow():
|
||||||
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
|
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
|
||||||
create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
|
create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
|
||||||
|
|
||||||
|
@ -1143,11 +1143,13 @@ def create_ui():
|
||||||
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
|
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
|
||||||
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
|
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
|
||||||
|
|
||||||
with gr.Row():
|
with FormRow():
|
||||||
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
|
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
|
||||||
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
|
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
|
||||||
|
|
||||||
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
|
config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
|
||||||
|
|
||||||
|
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
|
||||||
|
|
||||||
with gr.Column(variant='panel'):
|
with gr.Column(variant='panel'):
|
||||||
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
|
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
|
||||||
|
@ -1703,6 +1705,7 @@ def create_ui():
|
||||||
save_as_half,
|
save_as_half,
|
||||||
custom_name,
|
custom_name,
|
||||||
checkpoint_format,
|
checkpoint_format,
|
||||||
|
config_source,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
submit_result,
|
submit_result,
|
||||||
|
|
Loading…
Reference in a new issue