Merge branch 'master' into test_resolve_conflicts

This commit is contained in:
MalumaDev 2022-10-18 08:55:08 +02:00 committed by GitHub
commit 1997ccff13
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 287 additions and 60 deletions

View file

@ -22,6 +22,12 @@ jobs:
uses: actions/setup-python@v3 uses: actions/setup-python@v3
with: with:
python-version: 3.10.6 python-version: 3.10.6
- uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install PyLint - name: Install PyLint
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip

View file

@ -31,8 +31,8 @@ function imageMaskResize() {
wrapper.style.width = `${wW}px`; wrapper.style.width = `${wW}px`;
wrapper.style.height = `${wH}px`; wrapper.style.height = `${wH}px`;
wrapper.style.left = `${(w-wW)/2}px`; wrapper.style.left = `0px`;
wrapper.style.top = `${(h-wH)/2}px`; wrapper.style.top = `0px`;
canvases.forEach( c => { canvases.forEach( c => {
c.style.width = c.style.height = ''; c.style.width = c.style.height = '';

View file

@ -116,6 +116,7 @@ function showGalleryImage() {
e.dataset.modded = true; e.dataset.modded = true;
if(e && e.parentElement.tagName == 'DIV'){ if(e && e.parentElement.tagName == 'DIV'){
e.style.cursor='pointer' e.style.cursor='pointer'
e.style.userSelect='none'
e.addEventListener('click', function (evt) { e.addEventListener('click', function (evt) {
if(!opts.js_modal_lightbox) return; if(!opts.js_modal_lightbox) return;
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed) modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)

146
javascript/localization.js Normal file
View file

@ -0,0 +1,146 @@
// localization = {} -- the dict with translations is created by the backend
ignore_ids_for_localization={
setting_sd_hypernetwork: 'OPTION',
setting_sd_model_checkpoint: 'OPTION',
setting_realesrgan_enabled_models: 'OPTION',
modelmerger_primary_model_name: 'OPTION',
modelmerger_secondary_model_name: 'OPTION',
modelmerger_tertiary_model_name: 'OPTION',
train_embedding: 'OPTION',
train_hypernetwork: 'OPTION',
txt2img_style_index: 'OPTION',
txt2img_style2_index: 'OPTION',
img2img_style_index: 'OPTION',
img2img_style2_index: 'OPTION',
setting_random_artist_categories: 'SPAN',
setting_face_restoration_model: 'SPAN',
setting_realesrgan_enabled_models: 'SPAN',
extras_upscaler_1: 'SPAN',
extras_upscaler_2: 'SPAN',
}
re_num = /^[\.\d]+$/
re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
original_lines = {}
translated_lines = {}
function textNodesUnder(el){
var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
while(n=walk.nextNode()) a.push(n);
return a;
}
function canBeTranslated(node, text){
if(! text) return false;
if(! node.parentElement) return false;
parentType = node.parentElement.nodeName
if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
if (parentType=='OPTION' || parentType=='SPAN'){
pnode = node
for(var level=0; level<4; level++){
pnode = pnode.parentElement
if(! pnode) break;
if(ignore_ids_for_localization[pnode.id] == parentType) return false;
}
}
if(re_num.test(text)) return false;
if(re_emoji.test(text)) return false;
return true
}
function getTranslation(text){
if(! text) return undefined
if(translated_lines[text] === undefined){
original_lines[text] = 1
}
tl = localization[text]
if(tl !== undefined){
translated_lines[tl] = 1
}
return tl
}
function processTextNode(node){
text = node.textContent.trim()
if(! canBeTranslated(node, text)) return
tl = getTranslation(text)
if(tl !== undefined){
node.textContent = tl
}
}
function processNode(node){
if(node.nodeType == 3){
processTextNode(node)
return
}
if(node.title){
tl = getTranslation(node.title)
if(tl !== undefined){
node.title = tl
}
}
if(node.placeholder){
tl = getTranslation(node.placeholder)
if(tl !== undefined){
node.placeholder = tl
}
}
textNodesUnder(node).forEach(function(node){
processTextNode(node)
})
}
function dumpTranslations(){
dumped = {}
Object.keys(original_lines).forEach(function(text){
if(dumped[text] !== undefined) return
dumped[text] = localization[text] || text
})
return dumped
}
onUiUpdate(function(m){
m.forEach(function(mutation){
mutation.addedNodes.forEach(function(node){
processNode(node)
})
});
})
document.addEventListener("DOMContentLoaded", function() {
processNode(gradioApp())
})
function download_localization() {
text = JSON.stringify(dumpTranslations(), null, 4)
var element = document.createElement('a');
element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
element.setAttribute('download', "localization.json");
element.style.display = 'none';
document.body.appendChild(element);
element.click();
document.body.removeChild(element);
}

View file

@ -72,11 +72,17 @@ function check_gallery(id_gallery){
let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item') let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2') let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) { if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
//automatically re-open previously selected index (if exists) // automatically re-open previously selected index (if exists)
activeElement = document.activeElement; activeElement = gradioApp().activeElement;
galleryButtons[prevSelectedIndex].click(); galleryButtons[prevSelectedIndex].click();
showGalleryImage(); showGalleryImage();
if(activeElement) activeElement.focus()
if(activeElement){
// i fought this for about an hour; i don't know why the focus is lost or why this helps recover it
// if somenoe has a better solution please by all means
setTimeout(function() { activeElement.focus() }, 1);
}
} }
}) })
galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false }) galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false })

View file

@ -91,7 +91,8 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
pixels = tuple(np.array(small).flatten().tolist()) pixels = tuple(np.array(small).flatten().tolist())
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight,
resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels
c = cached_images.get(key) c = cached_images.get(key)
if c is None: if c is None:
@ -175,11 +176,14 @@ def run_pnginfo(image):
def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name): def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
def weighted_sum(theta0, theta1, theta2, alpha): def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1) return ((1 - alpha) * theta0) + (alpha * theta1)
def add_difference(theta0, theta1, theta2, alpha): def get_difference(theta1, theta2):
return theta0 + (theta1 - theta2) * alpha return theta1 - theta2
def add_difference(theta0, theta1_2_diff, alpha):
return theta0 + (alpha * theta1_2_diff)
primary_model_info = sd_models.checkpoints_list[primary_model_name] primary_model_info = sd_models.checkpoints_list[primary_model_name]
secondary_model_info = sd_models.checkpoints_list[secondary_model_name] secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
@ -198,23 +202,28 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
teritary_model = torch.load(teritary_model_info.filename, map_location='cpu') teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model) theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
else: else:
teritary_model = None
theta_2 = None theta_2 = None
theta_funcs = { theta_funcs = {
"Weighted sum": weighted_sum, "Weighted sum": (None, weighted_sum),
"Add difference": add_difference, "Add difference": (get_difference, add_difference),
} }
theta_func = theta_funcs[interp_method] theta_func1, theta_func2 = theta_funcs[interp_method]
print(f"Merging...") print(f"Merging...")
if theta_func1:
for key in tqdm.tqdm(theta_1.keys()):
if 'model' in key:
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
theta_1[key] = theta_func1(theta_1[key], t2)
del theta_2, teritary_model
for key in tqdm.tqdm(theta_0.keys()): for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1: if 'model' in key and key in theta_1:
t2 = (theta_2 or {}).get(key)
if t2 is None:
t2 = torch.zeros_like(theta_0[key])
theta_0[key] = theta_func(theta_0[key], theta_1[key], t2, multiplier) theta_0[key] = theta_func2(theta_0[key], theta_1[key], multiplier)
if save_as_half: if save_as_half:
theta_0[key] = theta_0[key].half() theta_0[key] = theta_0[key].half()

View file

@ -123,7 +123,7 @@ class InterrogateModels:
return caption[0] return caption[0]
def interrogate(self, pil_image, include_ranks=False): def interrogate(self, pil_image):
res = None res = None
try: try:
@ -156,10 +156,10 @@ class InterrogateModels:
for name, topn, items in self.categories: for name, topn, items in self.categories:
matches = self.rank(image_features, items, top_count=topn) matches = self.rank(image_features, items, top_count=topn)
for match, score in matches: for match, score in matches:
if include_ranks: if shared.opts.interrogate_return_ranks:
res += ", " + match res += f", ({match}:{score/100:.3f})"
else: else:
res += f", ({match}:{score})" res += ", " + match
except Exception: except Exception:
print(f"Error interrogating", file=sys.stderr) print(f"Error interrogating", file=sys.stderr)

31
modules/localization.py Normal file
View file

@ -0,0 +1,31 @@
import json
import os
import sys
import traceback
localizations = {}
def list_localizations(dirname):
localizations.clear()
for file in os.listdir(dirname):
fn, ext = os.path.splitext(file)
if ext.lower() != ".json":
continue
localizations[fn] = os.path.join(dirname, file)
def localization_js(current_localization_name):
fn = localizations.get(current_localization_name, None)
data = {}
if fn is not None:
try:
with open(fn, "r", encoding="utf8") as file:
data = json.load(file)
except Exception:
print(f"Error loading localization from {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return f"var localization = {json.dumps(data)}\n"

View file

@ -58,6 +58,9 @@ def load_scripts(basedir):
for filename in sorted(os.listdir(basedir)): for filename in sorted(os.listdir(basedir)):
path = os.path.join(basedir, filename) path = os.path.join(basedir, filename)
if os.path.splitext(path)[1].lower() != '.py':
continue
if not os.path.isfile(path): if not os.path.isfile(path):
continue continue

View file

@ -14,7 +14,7 @@ import modules.memmon
import modules.sd_models import modules.sd_models
import modules.styles import modules.styles
import modules.devices as devices import modules.devices as devices
from modules import sd_samplers, sd_models from modules import sd_samplers, sd_models, localization
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path from modules.paths import models_path, script_path, sd_path
@ -34,6 +34,7 @@ parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_pa
parser.add_argument("--aesthetic_embeddings-dir", type=str, default=os.path.join(script_path, 'aesthetic_embeddings'), parser.add_argument("--aesthetic_embeddings-dir", type=str, default=os.path.join(script_path, 'aesthetic_embeddings'),
help="aesthetic_embeddings directory(default: aesthetic_embeddings)") help="aesthetic_embeddings directory(default: aesthetic_embeddings)")
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
@ -106,6 +107,7 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
loaded_hypernetwork = None loaded_hypernetwork = None
aesthetic_embeddings = {} aesthetic_embeddings = {}
def update_aesthetic_embeddings(): def update_aesthetic_embeddings():
@ -116,6 +118,7 @@ def update_aesthetic_embeddings():
update_aesthetic_embeddings() update_aesthetic_embeddings()
def reload_hypernetworks(): def reload_hypernetworks():
global hypernetworks global hypernetworks
@ -163,6 +166,8 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = [] face_restorers = []
localization.list_localizations(cmd_opts.localizations_dir)
def realesrgan_models_names(): def realesrgan_models_names():
import modules.realesrgan_model import modules.realesrgan_model
@ -308,6 +313,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
})) }))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), { options_templates.update(options_section(('sampler-params', "Sampler parameters"), {

View file

@ -137,6 +137,7 @@ class EmbeddingDatabase:
continue continue
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.") print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
print("Embeddings:", ', '.join(self.word_embeddings.keys()))
def find_embedding_at_position(self, tokens, offset): def find_embedding_at_position(self, tokens, offset):
token = tokens[offset] token = tokens[offset]

View file

@ -23,7 +23,7 @@ import gradio as gr
import gradio.utils import gradio.utils
import gradio.routes import gradio.routes
from modules import sd_hijack, sd_models from modules import sd_hijack, sd_models, localization
from modules.paths import script_path from modules.paths import script_path
from modules.shared import opts, cmd_opts, restricted_opts, aesthetic_embeddings from modules.shared import opts, cmd_opts, restricted_opts, aesthetic_embeddings
@ -1102,10 +1102,10 @@ def create_ui(wrap_gradio_gpu_call):
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True) upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
with gr.Group(): with gr.Group():
extras_upscaler_1 = gr.Radio(label='Upscaler 1', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
with gr.Group(): with gr.Group():
extras_upscaler_2 = gr.Radio(label='Upscaler 2', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1) extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1)
with gr.Group(): with gr.Group():
@ -1282,10 +1282,10 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tab(label="Train"): with gr.Tab(label="Train"):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>") gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
with gr.Row(): with gr.Row():
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
with gr.Row(): with gr.Row():
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()]) train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
batch_size = gr.Number(label='Batch size', value=1, precision=0) batch_size = gr.Number(label='Batch size', value=1, precision=0)
@ -1452,16 +1452,18 @@ def create_ui(wrap_gradio_gpu_call):
else: else:
raise Exception(f'bad options item type: {str(t)} for key {key}') raise Exception(f'bad options item type: {str(t)} for key {key}')
elem_id = "setting_"+key
if info.refresh is not None: if info.refresh is not None:
if is_quicksettings: if is_quicksettings:
res = comp(label=info.label, value=fun, **(args or {})) res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
refresh_button = create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else: else:
with gr.Row(variant="compact"): with gr.Row(variant="compact"):
res = comp(label=info.label, value=fun, **(args or {})) res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
refresh_button = create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else: else:
res = comp(label=info.label, value=fun, **(args or {})) res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
return res return res
@ -1585,6 +1587,9 @@ Requested path was: {f}
with gr.Row(): with gr.Row():
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
with gr.Row():
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
@ -1595,6 +1600,13 @@ Requested path was: {f}
_js='function(){}' _js='function(){}'
) )
download_localization.click(
fn=lambda: None,
inputs=[],
outputs=[],
_js='download_localization'
)
def reload_scripts(): def reload_scripts():
modules.scripts.reload_script_body_only() modules.scripts.reload_script_body_only()
@ -1843,6 +1855,7 @@ Requested path was: {f}
visit(txt2img_interface, loadsave, "txt2img") visit(txt2img_interface, loadsave, "txt2img")
visit(img2img_interface, loadsave, "img2img") visit(img2img_interface, loadsave, "img2img")
visit(extras_interface, loadsave, "extras") visit(extras_interface, loadsave, "extras")
visit(modelmerger_interface, loadsave, "modelmerger")
if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
with open(ui_config_file, "w", encoding="utf8") as file: with open(ui_config_file, "w", encoding="utf8") as file:
@ -1859,6 +1872,7 @@ for filename in sorted(os.listdir(jsdir)):
with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile: with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
javascript += f"\n<script>{jsfile.read()}</script>" javascript += f"\n<script>{jsfile.read()}</script>"
javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"
if 'gradio_routes_templates_response' not in globals(): if 'gradio_routes_templates_response' not in globals():
def template_response(*args, **kwargs): def template_response(*args, **kwargs):

View file

@ -21,20 +21,20 @@ function onUiTabChange(callback){
uiTabChangeCallbacks.push(callback) uiTabChangeCallbacks.push(callback)
} }
function runCallback(x){ function runCallback(x, m){
try { try {
x() x(m)
} catch (e) { } catch (e) {
(console.error || console.log).call(console, e.message, e); (console.error || console.log).call(console, e.message, e);
} }
} }
function executeCallbacks(queue) { function executeCallbacks(queue, m) {
queue.forEach(runCallback) queue.forEach(function(x){runCallback(x, m)})
} }
document.addEventListener("DOMContentLoaded", function() { document.addEventListener("DOMContentLoaded", function() {
var mutationObserver = new MutationObserver(function(m){ var mutationObserver = new MutationObserver(function(m){
executeCallbacks(uiUpdateCallbacks); executeCallbacks(uiUpdateCallbacks, m);
const newTab = get_uiCurrentTab(); const newTab = get_uiCurrentTab();
if ( newTab && ( newTab !== uiCurrentTab ) ) { if ( newTab && ( newTab !== uiCurrentTab ) ) {
uiCurrentTab = newTab; uiCurrentTab = newTab;

View file

@ -233,6 +233,21 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_
return processed_result return processed_result
class SharedSettingsStackHelper(object):
def __enter__(self):
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
self.hypernetwork = opts.sd_hypernetwork
self.model = shared.sd_model
def __exit__(self, exc_type, exc_value, tb):
modules.sd_models.reload_model_weights(self.model)
hypernetwork.load_hypernetwork(self.hypernetwork)
hypernetwork.apply_strength()
opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*") re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")
@ -267,9 +282,6 @@ class Script(scripts.Script):
if not opts.return_grid: if not opts.return_grid:
p.batch_size = 1 p.batch_size = 1
CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
def process_axis(opt, vals): def process_axis(opt, vals):
if opt.label == 'Nothing': if opt.label == 'Nothing':
return [0] return [0]
@ -367,27 +379,19 @@ class Script(scripts.Script):
return process_images(pc) return process_images(pc)
processed = draw_xy_grid( with SharedSettingsStackHelper():
p, processed = draw_xy_grid(
xs=xs, p,
ys=ys, xs=xs,
x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], ys=ys,
y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
cell=cell, y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
draw_legend=draw_legend, cell=cell,
include_lone_images=include_lone_images draw_legend=draw_legend,
) include_lone_images=include_lone_images
)
if opts.grid_save: if opts.grid_save:
images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p) images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p)
# restore checkpoint in case it was changed by axes
modules.sd_models.reload_model_weights(shared.sd_model)
hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
hypernetwork.apply_strength()
opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers
return processed return processed

View file

@ -478,7 +478,7 @@ input[type="range"]{
padding: 0; padding: 0;
} }
#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name{ #refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
max-width: 2.5em; max-width: 2.5em;
min-width: 2.5em; min-width: 2.5em;
height: 2.4em; height: 2.4em;