Merge branch 'AUTOMATIC1111:master' into fix/alternating-words-emphasis

This commit is contained in:
catboxanon 2023-01-11 08:58:43 -05:00 committed by GitHub
commit 035f2af050
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 46 additions and 7 deletions

View file

@ -151,6 +151,7 @@ function showGalleryImage() {
e.addEventListener('mousedown', function (evt) { e.addEventListener('mousedown', function (evt) {
if(!opts.js_modal_lightbox || evt.button != 0) return; if(!opts.js_modal_lightbox || evt.button != 0) return;
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed) modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
evt.preventDefault()
showModal(evt) showModal(evt)
}, true); }, true);
} }

View file

@ -3,6 +3,7 @@ import math
import os import os
import sys import sys
import traceback import traceback
import shutil
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@ -248,7 +249,32 @@ def run_pnginfo(image):
return '', geninfo, info return '', geninfo, info
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format): def create_config(ckpt_result, config_source, a, b, c):
def config(x):
return sd_models.find_checkpoint_config(x) if x else None
if config_source == 0:
cfg = config(a) or config(b) or config(c)
elif config_source == 1:
cfg = config(b)
elif config_source == 2:
cfg = config(c)
else:
cfg = None
if cfg is None:
return
filename, _ = os.path.splitext(ckpt_result)
checkpoint_filename = filename + ".yaml"
print("Copying config:")
print(" from:", cfg)
print(" to:", checkpoint_filename)
shutil.copyfile(cfg, checkpoint_filename)
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
shared.state.begin() shared.state.begin()
shared.state.job = 'model-merge' shared.state.job = 'model-merge'
@ -356,6 +382,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
sd_models.list_models() sd_models.list_models()
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
print("Checkpoint saved.") print("Checkpoint saved.")
shared.state.textinfo = "Checkpoint saved to " + output_modelname shared.state.textinfo = "Checkpoint saved to " + output_modelname
shared.state.end() shared.state.end()

View file

@ -152,7 +152,7 @@ def basedir():
scripts_data = [] scripts_data = []
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"]) ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"]) ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
def list_scripts(scriptdirname, extension): def list_scripts(scriptdirname, extension):
@ -206,7 +206,7 @@ def load_scripts():
for key, script_class in module.__dict__.items(): for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script): if type(script_class) == type and issubclass(script_class, Script):
scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir)) scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
except Exception: except Exception:
print(f"Error loading script: {scriptfile.filename}", file=sys.stderr) print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
@ -241,7 +241,7 @@ class ScriptRunner:
self.alwayson_scripts.clear() self.alwayson_scripts.clear()
self.selectable_scripts.clear() self.selectable_scripts.clear()
for script_class, path, basedir in scripts_data: for script_class, path, basedir, script_module in scripts_data:
script = script_class() script = script_class()
script.filename = path script.filename = path
script.is_txt2img = not is_img2img script.is_txt2img = not is_img2img

View file

@ -333,10 +333,14 @@ def load_model(checkpoint_info=None):
timer = Timer() timer = Timer()
sd_model = None
try: try:
with sd_disable_initialization.DisableInitialization(): with sd_disable_initialization.DisableInitialization():
sd_model = instantiate_from_config(sd_config.model) sd_model = instantiate_from_config(sd_config.model)
except Exception as e: except Exception as e:
pass
if sd_model is None:
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
sd_model = instantiate_from_config(sd_config.model) sd_model = instantiate_from_config(sd_config.model)

View file

@ -9,6 +9,7 @@ import tqdm
import html import html
import datetime import datetime
import csv import csv
import safetensors.torch
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
@ -150,6 +151,8 @@ class EmbeddingDatabase:
name = data.get('name', name) name = data.get('name', name)
elif ext in ['.BIN', '.PT']: elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu") data = torch.load(path, map_location="cpu")
elif ext in ['.SAFETENSORS']:
data = safetensors.torch.load_file(path, device="cpu")
else: else:
return return

View file

@ -1129,7 +1129,7 @@ def create_ui():
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>") gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
with gr.Row(): with FormRow():
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
@ -1143,11 +1143,13 @@ def create_ui():
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
with gr.Row(): with FormRow():
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
@ -1703,6 +1705,7 @@ def create_ui():
save_as_half, save_as_half,
custom_name, custom_name,
checkpoint_format, checkpoint_format,
config_source,
], ],
outputs=[ outputs=[
submit_result, submit_result,