Merge branch 'master' into saving

This commit is contained in:
WDevelopsWebApps 2022-09-29 12:19:13 +02:00 committed by GitHub
commit f28ce3e3a1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 472 additions and 131 deletions

2
.gitignore vendored
View file

@ -4,7 +4,7 @@ __pycache__
/venv /venv
/tmp /tmp
/model.ckpt /model.ckpt
/models/*.ckpt /models/**/*.ckpt
/GFPGANv1.3.pth /GFPGANv1.3.pth
/gfpgan/weights/*.pth /gfpgan/weights/*.pth
/ui-config.json /ui-config.json

View file

@ -15,6 +15,7 @@ titles = {
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed", "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
"\u{1f3a8}": "Add a random artist to the prompt.", "\u{1f3a8}": "Add a random artist to the prompt.",
"\u2199\ufe0f": "Read generation parameters from prompt into user interface.", "\u2199\ufe0f": "Read generation parameters from prompt into user interface.",
"\uD83D\uDCC2": "Open images output directory",
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",

View file

@ -182,4 +182,23 @@ onUiUpdate(function(){
}); });
json_elem.parentElement.style.display="none" json_elem.parentElement.style.display="none"
if (!txt2img_textarea) {
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
}
if (!img2img_textarea) {
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
}
}) })
let txt2img_textarea, img2img_textarea = undefined;
let wait_time = 800
let token_timeout;
function update_token_counter(button_id) {
if (token_timeout)
clearTimeout(token_timeout);
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
}

View file

@ -15,14 +15,14 @@ torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "") commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
k_diffusion_package = os.environ.get('K_DIFFUSION_PACKAGE', "git+https://github.com/crowsonkb/k-diffusion.git@1a0703dfb7d24d8806267c3e7ccc4caf67fd1331")
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379") gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc") stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6") taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "9e3002b7cd64df7870e08527b7664eb2f2f5f3f5")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
ldsr_commit_hash = os.environ.get('LDSR_COMMIT_HASH',"abf33e7002d59d9085081bce93ec798dcabd49af") ldsr_commit_hash = os.environ.get('LDSR_COMMIT_HASH', "abf33e7002d59d9085081bce93ec798dcabd49af")
args = shlex.split(commandline_args) args = shlex.split(commandline_args)
@ -110,9 +110,6 @@ if not is_installed("torch") or not is_installed("torchvision"):
if not skip_torch_cuda_test: if not skip_torch_cuda_test:
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'") run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
if not is_installed("k_diffusion.sampling"):
run_pip(f"install {k_diffusion_package}", "k-diffusion")
if not is_installed("gfpgan"): if not is_installed("gfpgan"):
run_pip(f"install {gfpgan_package}", "gfpgan") run_pip(f"install {gfpgan_package}", "gfpgan")
@ -120,6 +117,7 @@ os.makedirs(dir_repos, exist_ok=True)
git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash) git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash) git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash) git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
# Using my repo until my changes are merged, as this makes interfacing with our version of SD-web a lot easier # Using my repo until my changes are merged, as this makes interfacing with our version of SD-web a lot easier

View file

@ -6,13 +6,14 @@ from PIL import Image
import torch import torch
import tqdm import tqdm
from modules import processing, shared, images, devices from modules import processing, shared, images, devices, sd_models
from modules.shared import opts from modules.shared import opts
import modules.gfpgan_model import modules.gfpgan_model
from modules.ui import plaintext_to_html from modules.ui import plaintext_to_html
import modules.codeformer_model import modules.codeformer_model
import piexif import piexif
import piexif.helper import piexif.helper
import gradio as gr
cached_images = {} cached_images = {}
@ -140,7 +141,7 @@ def run_pnginfo(image):
return '', geninfo, info return '', geninfo, info
def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount): def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name):
# Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation) # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
def weighted_sum(theta0, theta1, alpha): def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1) return ((1 - alpha) * theta0) + (alpha * theta1)
@ -150,23 +151,20 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
alpha = alpha * alpha * (3 - (2 * alpha)) alpha = alpha * alpha * (3 - (2 * alpha))
return theta0 + ((theta1 - theta0) * alpha) return theta0 + ((theta1 - theta0) * alpha)
if os.path.exists(primary_model_name): # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
primary_model_filename = primary_model_name def inv_sigmoid(theta0, theta1, alpha):
primary_model_name = os.path.splitext(os.path.basename(primary_model_name))[0] import math
else: alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
primary_model_filename = 'models/' + primary_model_name + '.ckpt' return theta0 + ((theta1 - theta0) * alpha)
if os.path.exists(secondary_model_name): primary_model_info = sd_models.checkpoints_list[primary_model_name]
secondary_model_filename = secondary_model_name secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
secondary_model_name = os.path.splitext(os.path.basename(secondary_model_name))[0]
else:
secondary_model_filename = 'models/' + secondary_model_name + '.ckpt'
print(f"Loading {primary_model_filename}...") print(f"Loading {primary_model_info.filename}...")
primary_model = torch.load(primary_model_filename, map_location='cpu') primary_model = torch.load(primary_model_info.filename, map_location='cpu')
print(f"Loading {secondary_model_filename}...") print(f"Loading {secondary_model_info.filename}...")
secondary_model = torch.load(secondary_model_filename, map_location='cpu') secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
theta_0 = primary_model['state_dict'] theta_0 = primary_model['state_dict']
theta_1 = secondary_model['state_dict'] theta_1 = secondary_model['state_dict']
@ -174,21 +172,31 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
theta_funcs = { theta_funcs = {
"Weighted Sum": weighted_sum, "Weighted Sum": weighted_sum,
"Sigmoid": sigmoid, "Sigmoid": sigmoid,
"Inverse Sigmoid": inv_sigmoid,
} }
theta_func = theta_funcs[interp_method] theta_func = theta_funcs[interp_method]
print(f"Merging...") print(f"Merging...")
for key in tqdm.tqdm(theta_0.keys()): for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1: if 'model' in key and key in theta_1:
theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint
if save_as_half:
theta_0[key] = theta_0[key].half()
for key in theta_1.keys(): for key in theta_1.keys():
if 'model' in key and key not in theta_0: if 'model' in key and key not in theta_0:
theta_0[key] = theta_1[key] theta_0[key] = theta_1[key]
if save_as_half:
theta_0[key] = theta_0[key].half()
filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
filename = filename if custom_name == '' else (custom_name + '.ckpt')
output_modelname = os.path.join(shared.cmd_opts.ckpt_dir, filename)
output_modelname = 'models/' + primary_model_name + '_' + str(round(interp_amount,2)) + '-' + secondary_model_name + '_' + str(round((float(1.0) - interp_amount),2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
print(f"Saving to {output_modelname}...") print(f"Saving to {output_modelname}...")
torch.save(primary_model, output_modelname) torch.save(primary_model, output_modelname)
sd_models.list_models()
print(f"Checkpoint saved.") print(f"Checkpoint saved.")
return "Checkpoint saved to " + output_modelname return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)]

View file

@ -124,4 +124,4 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
if opts.samples_log_stdout: if opts.samples_log_stdout:
print(generation_info_js) print(generation_info_js)
return processed.images, generation_info_js, plaintext_to_html(processed.info) return processed.images, generation_info_js, plaintext_to_html(processed.info)

View file

@ -20,6 +20,7 @@ path_dirs = [
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer'), (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer'),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP'), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP'),
(os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR'), (os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR'),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion'),
] ]
paths = {} paths = {}

View file

@ -49,7 +49,7 @@ def apply_color_correction(correction, image):
class StableDiffusionProcessing: class StableDiffusionProcessing:
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None): def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
self.sd_model = sd_model self.sd_model = sd_model
self.outpath_samples: str = outpath_samples self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids self.outpath_grids: str = outpath_grids
@ -75,15 +75,15 @@ class StableDiffusionProcessing:
self.do_not_save_grid: bool = do_not_save_grid self.do_not_save_grid: bool = do_not_save_grid
self.extra_generation_params: dict = extra_generation_params or {} self.extra_generation_params: dict = extra_generation_params or {}
self.overlay_images = overlay_images self.overlay_images = overlay_images
self.eta = eta
self.paste_to = None self.paste_to = None
self.color_corrections = None self.color_corrections = None
self.denoising_strength: float = 0 self.denoising_strength: float = 0
self.eta = opts.eta
self.ddim_discretize = opts.ddim_discretize self.ddim_discretize = opts.ddim_discretize
self.s_churn = opts.s_churn self.s_churn = opts.s_churn
self.s_tmin = opts.s_tmin self.s_tmin = opts.s_tmin
self.s_tmax = float('inf') # not representable as a standard ui option self.s_tmax = float('inf') # not representable as a standard ui option
self.s_noise = opts.s_noise self.s_noise = opts.s_noise
if not seed_enable_extras: if not seed_enable_extras:
@ -100,7 +100,7 @@ class StableDiffusionProcessing:
class Processed: class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0): def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
self.images = images_list self.images = images_list
self.prompt = p.prompt self.prompt = p.prompt
self.negative_prompt = p.negative_prompt self.negative_prompt = p.negative_prompt
@ -139,6 +139,7 @@ class Processed:
self.all_prompts = all_prompts or [self.prompt] self.all_prompts = all_prompts or [self.prompt]
self.all_seeds = all_seeds or [self.seed] self.all_seeds = all_seeds or [self.seed]
self.all_subseeds = all_subseeds or [self.subseed] self.all_subseeds = all_subseeds or [self.subseed]
self.infotexts = infotexts or [info]
def js(self): def js(self):
obj = { obj = {
@ -165,6 +166,7 @@ class Processed:
"denoising_strength": self.denoising_strength, "denoising_strength": self.denoising_strength,
"extra_generation_params": self.extra_generation_params, "extra_generation_params": self.extra_generation_params,
"index_of_first_image": self.index_of_first_image, "index_of_first_image": self.index_of_first_image,
"infotexts": self.infotexts,
} }
return json.dumps(obj) return json.dumps(obj)
@ -269,6 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None), "Denoising strength": getattr(p, 'denoising_strength', None),
"Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
} }
generation_params.update(p.extra_generation_params) generation_params.update(p.extra_generation_params)
@ -277,7 +280,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else "" negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments]) return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
def process_images(p: StableDiffusionProcessing) -> Processed: def process_images(p: StableDiffusionProcessing) -> Processed:
@ -322,6 +325,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if os.path.exists(cmd_opts.embeddings_dir): if os.path.exists(cmd_opts.embeddings_dir):
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model) model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
infotexts = []
output_images = [] output_images = []
precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope) ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
@ -404,6 +408,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if opts.samples_save and not p.do_not_save_samples: if opts.samples_save and not p.do_not_save_samples:
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p) images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
infotexts.append(infotext(n, i))
output_images.append(image) output_images.append(image)
state.nextjob() state.nextjob()
@ -416,6 +421,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
grid = images.image_grid(output_images, p.batch_size) grid = images.image_grid(output_images, p.batch_size)
if opts.return_grid: if opts.return_grid:
infotexts.insert(0, infotext())
output_images.insert(0, grid) output_images.insert(0, grid)
index_of_first_image = 1 index_of_first_image = 1
@ -423,7 +429,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc() devices.torch_gc()
return Processed(p, output_images, all_seeds[0], infotext(), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image) return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):

View file

@ -126,5 +126,93 @@ def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
return res return res
re_attention = re.compile(r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""", re.X)
#get_learned_conditioning_prompt_schedules(["fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"], 100)
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs: text and its assoicated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
Example:
'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).'
produces:
[
['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]
]
"""
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith('\\'):
res.append([text[1:], 1.0])
elif text == '(':
round_brackets.append(len(res))
elif text == '[':
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), float(weight))
elif text == ')' and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == ']' and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
res.append([text, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
return res

View file

@ -55,7 +55,7 @@ def load_scripts(basedir):
if not os.path.exists(basedir): if not os.path.exists(basedir):
return return
for filename in os.listdir(basedir): for filename in sorted(os.listdir(basedir)):
path = os.path.join(basedir, filename) path = os.path.join(basedir, filename)
if not os.path.isfile(path): if not os.path.isfile(path):

View file

@ -6,6 +6,7 @@ import torch
import numpy as np import numpy as np
from torch import einsum from torch import einsum
from modules import prompt_parser
from modules.shared import opts, device, cmd_opts from modules.shared import opts, device, cmd_opts
from ldm.util import default from ldm.util import default
@ -180,6 +181,7 @@ class StableDiffusionModelHijack:
dir_mtime = None dir_mtime = None
layers = None layers = None
circular_enabled = False circular_enabled = False
clip = None
def load_textual_inversion_embeddings(self, dirname, model): def load_textual_inversion_embeddings(self, dirname, model):
mt = os.path.getmtime(dirname) mt = os.path.getmtime(dirname)
@ -210,6 +212,7 @@ class StableDiffusionModelHijack:
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it' assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1] emb = next(iter(param_dict.items()))[1]
# diffuser concepts
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it' assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
@ -235,13 +238,14 @@ class StableDiffusionModelHijack:
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
continue continue
print(f"Loaded a total of {len(self.word_embeddings)} text inversion embeddings.") print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
def hijack(self, m): def hijack(self, m):
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
self.clip = m.cond_stage_model
if cmd_opts.opt_split_attention_v1: if cmd_opts.opt_split_attention_v1:
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
@ -268,6 +272,11 @@ class StableDiffusionModelHijack:
for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]: for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
layer.padding_mode = 'circular' if enable else 'zeros' layer.padding_mode = 'circular' if enable else 'zeros'
def tokenize(self, text):
max_length = self.clip.max_length - 2
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
return remade_batch_tokens[0], token_count, max_length
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack): def __init__(self, wrapped, hijack):
@ -294,14 +303,101 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0: if mult != 1.0:
self.token_mults[ident] = mult self.token_mults[ident] = mult
def forward(self, text):
self.hijack.fixes = [] def tokenize_line(self, line, used_custom_terms, hijack_comments):
self.hijack.comments = [] id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
maxlen = self.wrapped.max_length
if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(line)
else:
parsed = [[line, 1.0]]
tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
fixes = []
remade_tokens = []
multipliers = []
for tokens, (text, weight) in zip(tokenized, parsed):
i = 0
while i < len(tokens):
token = tokens[i]
possible_matches = self.hijack.ids_lookup.get(token, None)
if possible_matches is None:
remade_tokens.append(token)
multipliers.append(weight)
else:
found = False
for ids, word in possible_matches:
if tokens[i:i + len(ids)] == ids:
emb_len = int(self.hijack.word_embeddings[word].shape[0])
fixes.append((len(remade_tokens), word))
remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len
i += len(ids) - 1
found = True
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
break
if not found:
remade_tokens.append(token)
multipliers.append(weight)
i += 1
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
ovf = remade_tokens[maxlen - 2:]
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
return remade_tokens, fixes, multipliers, token_count
def process_text(self, texts):
used_custom_terms = []
remade_batch_tokens = [] remade_batch_tokens = []
hijack_comments = []
hijack_fixes = []
token_count = 0
cache = {}
batch_multipliers = []
for line in texts:
if line in cache:
remade_tokens, fixes, multipliers = cache[line]
else:
remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
cache[line] = (remade_tokens, fixes, multipliers)
remade_batch_tokens.append(remade_tokens)
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id id_end = self.wrapped.tokenizer.eos_token_id
maxlen = self.wrapped.max_length maxlen = self.wrapped.max_length
used_custom_terms = [] used_custom_terms = []
remade_batch_tokens = []
overflowing_words = []
hijack_comments = []
hijack_fixes = []
token_count = 0
cache = {} cache = {}
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"] batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
@ -353,9 +449,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
ovf = remade_tokens[maxlen - 2:] ovf = remade_tokens[maxlen - 2:]
overflowing_words = [vocab.get(int(x), "") for x in ovf] overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers) cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
@ -364,11 +459,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
remade_batch_tokens.append(remade_tokens) remade_batch_tokens.append(remade_tokens)
self.hijack.fixes.append(fixes) hijack_fixes.append(fixes)
batch_multipliers.append(multipliers) batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def forward(self, text):
if opts.use_old_emphasis_implementation:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
self.hijack.fixes = hijack_fixes
self.hijack.comments = hijack_comments
if len(used_custom_terms) > 0: if len(used_custom_terms) > 0:
self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
tokens = torch.asarray(remade_batch_tokens).to(device) tokens = torch.asarray(remade_batch_tokens).to(device)
outputs = self.wrapped.transformer(input_ids=tokens) outputs = self.wrapped.transformer(input_ids=tokens)

View file

@ -23,6 +23,10 @@ except Exception:
pass pass
def checkpoint_tiles():
return sorted([x.title for x in checkpoints_list.values()])
def list_models(): def list_models():
checkpoints_list.clear() checkpoints_list.clear()
@ -39,13 +43,14 @@ def list_models():
if name.startswith("\\") or name.startswith("/"): if name.startswith("\\") or name.startswith("/"):
name = name[1:] name = name[1:]
return f'{name} [{h}]' shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
return f'{name} [{h}]', shortname
cmd_ckpt = shared.cmd_opts.ckpt cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt): if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt) h = model_hash(cmd_ckpt)
title = modeltitle(cmd_ckpt, h) title, model_name = modeltitle(cmd_ckpt, h)
model_name = title.rsplit(".",1)[0] # remove extension if present
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name) checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name)
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found: {cmd_ckpt}", file=sys.stderr) print(f"Checkpoint in --ckpt argument not found: {cmd_ckpt}", file=sys.stderr)
@ -53,8 +58,7 @@ def list_models():
if os.path.exists(model_dir): if os.path.exists(model_dir):
for filename in glob.glob(model_dir + '/**/*.ckpt', recursive=True): for filename in glob.glob(model_dir + '/**/*.ckpt', recursive=True):
h = model_hash(filename) h = model_hash(filename)
title = modeltitle(filename, h) title, model_name = modeltitle(filename, h)
model_name = title.rsplit(".",1)[0] # remove extension if present
checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name) checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name)

View file

@ -40,10 +40,8 @@ samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
sampler_extra_params = { sampler_extra_params = {
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
'sample_euler_ancestral': ['eta'],
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
'sample_dpm_2_ancestral': ['eta'],
} }
def setup_img2img_steps(p, steps=None): def setup_img2img_steps(p, steps=None):
@ -101,6 +99,8 @@ class VanillaStableDiffusionSampler:
self.init_latent = None self.init_latent = None
self.sampler_noises = None self.sampler_noises = None
self.step = 0 self.step = 0
self.eta = None
self.default_eta = 0.0
def number_of_needed_noises(self, p): def number_of_needed_noises(self, p):
return 0 return 0
@ -123,20 +123,29 @@ class VanillaStableDiffusionSampler:
self.step += 1 self.step += 1
return res return res
def initialize(self, p):
self.eta = p.eta or opts.eta_ddim
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
if hasattr(self.sampler, fieldname):
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None): def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
steps, t_enc = setup_img2img_steps(p, steps) steps, t_enc = setup_img2img_steps(p, steps)
self.initialize(p)
# existing code fails with cetain step counts, like 9 # existing code fails with cetain step counts, like 9
try: try:
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=p.ddim_eta, ddim_discretize=p.ddim_discretize, verbose=False) self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
except Exception: except Exception:
self.sampler.make_schedule(ddim_num_steps=steps+1,ddim_eta=p.ddim_eta, ddim_discretize=p.ddim_discretize, verbose=False) self.sampler.make_schedule(ddim_num_steps=steps+1, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise) x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
self.sampler.p_sample_ddim = self.p_sample_ddim_hook
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
self.init_latent = x self.init_latent = x
self.step = 0 self.step = 0
@ -145,11 +154,8 @@ class VanillaStableDiffusionSampler:
return samples return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None): def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
for fieldname in ['p_sample_ddim', 'p_sample_plms']: self.initialize(p)
if hasattr(self.sampler, fieldname):
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
self.mask = None
self.nmask = None
self.init_latent = None self.init_latent = None
self.step = 0 self.step = 0
@ -157,9 +163,9 @@ class VanillaStableDiffusionSampler:
# existing code fails with cetin step counts, like 9 # existing code fails with cetin step counts, like 9
try: try:
samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.eta) samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
except Exception: except Exception:
samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.eta) samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
return samples_ddim return samples_ddim
@ -237,6 +243,8 @@ class KDiffusionSampler:
self.sampler_noises = None self.sampler_noises = None
self.sampler_noise_index = 0 self.sampler_noise_index = 0
self.stop_at = None self.stop_at = None
self.eta = None
self.default_eta = 1.0
def callback_state(self, d): def callback_state(self, d):
store_latent(d["denoised"]) store_latent(d["denoised"])
@ -255,22 +263,12 @@ class KDiffusionSampler:
self.sampler_noise_index += 1 self.sampler_noise_index += 1
return res return res
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None): def initialize(self, p):
steps, t_enc = setup_img2img_steps(p, steps)
sigmas = self.model_wrap.get_sigmas(steps)
noise = noise * sigmas[steps - t_enc - 1]
xi = x + noise
sigma_sched = sigmas[steps - t_enc - 1:]
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
self.model_wrap_cfg.init_latent = x
self.model_wrap.step = 0 self.model_wrap.step = 0
self.sampler_noise_index = 0 self.sampler_noise_index = 0
self.eta = p.eta or opts.eta_ancestral
if hasattr(k_diffusion.sampling, 'trange'): if hasattr(k_diffusion.sampling, 'trange'):
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs) k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs)
@ -283,6 +281,25 @@ class KDiffusionSampler:
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
extra_params_kwargs[param_name] = getattr(p, param_name) extra_params_kwargs[param_name] = getattr(p, param_name)
if 'eta' in inspect.signature(self.func).parameters:
extra_params_kwargs['eta'] = self.eta
return extra_params_kwargs
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
steps, t_enc = setup_img2img_steps(p, steps)
sigmas = self.model_wrap.get_sigmas(steps)
noise = noise * sigmas[steps - t_enc - 1]
xi = x + noise
extra_params_kwargs = self.initialize(p)
sigma_sched = sigmas[steps - t_enc - 1:]
self.model_wrap_cfg.init_latent = x
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs) return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None): def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
@ -291,19 +308,7 @@ class KDiffusionSampler:
sigmas = self.model_wrap.get_sigmas(steps) sigmas = self.model_wrap.get_sigmas(steps)
x = x * sigmas[0] x = x * sigmas[0]
self.model_wrap_cfg.step = 0 extra_params_kwargs = self.initialize(p)
self.sampler_noise_index = 0
if hasattr(k_diffusion.sampling, 'trange'):
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs)
if self.sampler_noises is not None:
k_diffusion.sampling.torch = TorchHijack(self)
extra_params_kwargs = {}
for param_name in self.extra_params:
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
extra_params_kwargs[param_name] = getattr(p, param_name)
samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs) samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)

View file

@ -143,6 +143,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"), "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"), "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
})) }))
options_templates.update(options_section(('saving-paths', "Paths for saving"), { options_templates.update(options_section(('saving-paths', "Paths for saving"), {
@ -180,7 +181,6 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
"face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}), "face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"), "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
"save_selected_only": OptionInfo(False, "When using 'Save' button, only save a single selected image"),
})) }))
options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('system', "System"), {
@ -190,12 +190,13 @@ options_templates.update(options_section(('system', "System"), {
})) }))
options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": [x.title for x in modules.sd_models.checkpoints_list.values()]}), "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": modules.sd_models.checkpoint_tiles()}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text and [text] to make it pay less attention"), "enable_emphasis": OptionInfo(True, "Eemphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"filter_nsfw": OptionInfo(False, "Filter NSFW content"), "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
@ -221,8 +222,9 @@ options_templates.update(options_section(('ui', "User interface"), {
})) }))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), { options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"eta": OptionInfo(0.0, "DDIM and K Ancestral eta", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform','quad']}), "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),

View file

@ -9,10 +9,12 @@ import random
import sys import sys
import time import time
import traceback import traceback
import platform
import subprocess as sp
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image, PngImagePlugin
import gradio as gr import gradio as gr
import gradio.utils import gradio.utils
@ -22,6 +24,7 @@ from modules.paths import script_path
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts
import modules.shared as shared import modules.shared as shared
from modules.sd_samplers import samplers, samplers_for_img2img from modules.sd_samplers import samplers, samplers_for_img2img
from modules.sd_hijack import model_hijack
import modules.ldsr_model import modules.ldsr_model
import modules.scripts import modules.scripts
import modules.gfpgan_model import modules.gfpgan_model
@ -61,7 +64,7 @@ random_symbol = '\U0001f3b2\ufe0f' # 🎲️
reuse_symbol = '\u267b\ufe0f' # ♻️ reuse_symbol = '\u267b\ufe0f' # ♻️
art_symbol = '\U0001f3a8' # 🎨 art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙ paste_symbol = '\u2199\ufe0f' # ↙
folder_symbol = '\uD83D\uDCC2'
def plaintext_to_html(text): def plaintext_to_html(text):
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>" text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
@ -102,6 +105,7 @@ def save_files(js_data, images, index):
setattr(self, key, value) setattr(self, key, value)
data = json.loads(js_data) data = json.loads(js_data)
p = MyObject(data) p = MyObject(data)
path = opts.outdir_save path = opts.outdir_save
save_to_dirs = opts.save_to_dirs save_to_dirs = opts.save_to_dirs
@ -111,10 +115,14 @@ def save_files(js_data, images, index):
path = os.path.join(opts.outdir_save, dirname) path = os.path.join(opts.outdir_save, dirname)
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
if index > -1 and opts.save_selected_only and (index > 0 or not opts.return_grid): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
images = [images[index]] images = [images[index]]
data["seed"] += (index - 1 if opts.return_grid else index) infotexts = [data["infotexts"][index]]
else:
infotexts = data["infotexts"]
with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
at_start = file.tell() == 0 at_start = file.tell() == 0
@ -137,8 +145,11 @@ def save_files(js_data, images, index):
if filedata.startswith("data:image/png;base64,"): if filedata.startswith("data:image/png;base64,"):
filedata = filedata[len("data:image/png;base64,"):] filedata = filedata[len("data:image/png;base64,"):]
with open(filepath, "wb") as imgfile: pnginfo = PngImagePlugin.PngInfo()
imgfile.write(base64.decodebytes(filedata.encode('utf-8'))) pnginfo.add_text('parameters', infotexts[i])
image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8'))))
image.save(filepath, quality=opts.jpeg_quality, pnginfo=pnginfo)
filenames.append(filename) filenames.append(filename)
@ -350,6 +361,10 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
outputs=[seed, dummy_component] outputs=[seed, dummy_component]
) )
def update_token_counter(text):
tokens, token_count, max_length = model_hijack.tokenize(text)
style_class = ' class="red"' if (token_count > max_length) else ""
return f"<span {style_class}>{token_count}/{max_length}</span>"
def create_toprow(is_img2img): def create_toprow(is_img2img):
id_part = "img2img" if is_img2img else "txt2img" id_part = "img2img" if is_img2img else "txt2img"
@ -359,11 +374,14 @@ def create_toprow(is_img2img):
with gr.Row(): with gr.Row():
with gr.Column(scale=80): with gr.Column(scale=80):
with gr.Row(): with gr.Row():
prompt = gr.Textbox(label="Prompt", elem_id="prompt", show_label=False, placeholder="Prompt", lines=2) prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2)
with gr.Column(scale=1, elem_id="roll_col"): with gr.Column(scale=1, elem_id="roll_col"):
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
paste = gr.Button(value=paste_symbol, elem_id="paste") paste = gr.Button(value=paste_symbol, elem_id="paste")
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter])
with gr.Column(scale=10, elem_id="style_pos_col"): with gr.Column(scale=10, elem_id="style_pos_col"):
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
@ -470,6 +488,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
send_to_img2img = gr.Button('Send to img2img') send_to_img2img = gr.Button('Send to img2img')
send_to_inpaint = gr.Button('Send to inpaint') send_to_inpaint = gr.Button('Send to inpaint')
send_to_extras = gr.Button('Send to extras') send_to_extras = gr.Button('Send to extras')
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id)
with gr.Group(): with gr.Group():
html_info = gr.HTML() html_info = gr.HTML()
@ -646,6 +666,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
img2img_send_to_img2img = gr.Button('Send to img2img') img2img_send_to_img2img = gr.Button('Send to img2img')
img2img_send_to_inpaint = gr.Button('Send to inpaint') img2img_send_to_inpaint = gr.Button('Send to inpaint')
img2img_send_to_extras = gr.Button('Send to extras') img2img_send_to_extras = gr.Button('Send to extras')
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id)
with gr.Group(): with gr.Group():
html_info = gr.HTML() html_info = gr.HTML()
@ -818,6 +840,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
html_info = gr.HTML() html_info = gr.HTML()
extras_send_to_img2img = gr.Button('Send to img2img') extras_send_to_img2img = gr.Button('Send to img2img')
extras_send_to_inpaint = gr.Button('Send to inpaint') extras_send_to_inpaint = gr.Button('Send to inpaint')
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else ''
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
submit.click( submit.click(
fn=run_extras, fn=run_extras,
@ -878,32 +902,20 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.Blocks() as modelmerger_interface: with gr.Blocks() as modelmerger_interface:
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>/models</b> directory.</p>") gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
with gr.Row(): with gr.Row():
ckpt_name_list = sorted([x.model_name for x in modules.sd_models.checkpoints_list.values()]) primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
primary_model_name = gr.Dropdown(ckpt_name_list, elem_id="modelmerger_primary_model_name", label="Primary Model Name") secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
secondary_model_name = gr.Dropdown(ckpt_name_list, elem_id="modelmerger_secondary_model_name", label="Secondary Model Name") custom_name = gr.Textbox(label="Custom Name (Optional)")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3) interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3)
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid"], value="Weighted Sum", label="Interpolation Method") interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
submit = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') save_as_half = gr.Checkbox(value=False, label="Safe as float16")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
submit.click(
fn=run_modelmerger,
inputs=[
primary_model_name,
secondary_model_name,
interp_method,
interp_amount
],
outputs=[
submit_result,
]
)
def create_setting_component(key): def create_setting_component(key):
def fun(): def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default return opts.data[key] if key in opts.data else opts.data_labels[key].default
@ -927,6 +939,17 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
return comp(label=info.label, value=fun, **(args or {})) return comp(label=info.label, value=fun, **(args or {}))
components = [] components = []
component_dict = {}
def open_folder(f):
if not shared.cmd_opts.hide_ui_dir_config:
path = os.path.normpath(f)
if platform.system() == "Windows":
os.startfile(path)
elif platform.system() == "Darwin":
sp.Popen(["open", path])
else:
sp.Popen(["xdg-open", path])
def run_settings(*args): def run_settings(*args):
changed = 0 changed = 0
@ -982,7 +1005,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1])) gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))
components.append(create_setting_component(k)) component = create_setting_component(k)
component_dict[k] = component
components.append(component)
items_displayed += 1 items_displayed += 1
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
@ -1032,7 +1057,34 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
inputs=components, inputs=components,
outputs=[result, text_settings], outputs=[result, text_settings],
) )
def modelmerger(*args):
try:
results = run_modelmerger(*args)
except Exception as e:
print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
modules.sd_models.list_models() #To remove the potentially missing models from the list
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
return results
modelmerger_merge.click(
fn=modelmerger,
inputs=[
primary_model_name,
secondary_model_name,
interp_method,
interp_amount,
save_as_half,
custom_name,
],
outputs=[
submit_result,
primary_model_name,
secondary_model_name,
component_dict['sd_model_checkpoint'],
]
)
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Seed', 'Size-1', 'Size-2'] paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Seed', 'Size-1', 'Size-2']
txt2img_fields = [field for field,name in txt2img_paste_fields if name in paste_field_names] txt2img_fields = [field for field,name in txt2img_paste_fields if name in paste_field_names]
img2img_fields = [field for field,name in img2img_paste_fields if name in paste_field_names] img2img_fields = [field for field,name in img2img_paste_fields if name in paste_field_names]
@ -1071,6 +1123,24 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
outputs=[extras_image], outputs=[extras_image],
) )
open_txt2img_folder.click(
fn=lambda: open_folder(opts.outdir_samples or opts.outdir_txt2img_samples),
inputs=[],
outputs=[],
)
open_img2img_folder.click(
fn=lambda: open_folder(opts.outdir_samples or opts.outdir_img2img_samples),
inputs=[],
outputs=[],
)
open_extras_folder.click(
fn=lambda: open_folder(opts.outdir_samples or opts.outdir_extras_samples),
inputs=[],
outputs=[],
)
img2img_send_to_extras.click( img2img_send_to_extras.click(
fn=lambda x: image_from_url_text(x), fn=lambda x: image_from_url_text(x),
_js="extract_image_from_gallery_extras", _js="extract_image_from_gallery_extras",

View file

@ -6,7 +6,6 @@ font-roboto
gfpgan gfpgan
gradio gradio
invisible-watermark invisible-watermark
git+https://github.com/crowsonkb/k-diffusion.git
numpy numpy
omegaconf omegaconf
piexif piexif
@ -16,5 +15,12 @@ realesrgan
scikit-image>=0.19 scikit-image>=0.19
git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379 git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379
timm==0.4.12 timm==0.4.12
transformers transformers==4.19.2
torch torch
einops
jsonmerge
clean-fid
git+https://github.com/openai/CLIP@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
resize-right
torchdiffeq
kornia

View file

@ -14,4 +14,11 @@ fonts
font-roboto font-roboto
timm==0.6.7 timm==0.6.7
fairscale==0.4.9 fairscale==0.4.9
piexif==1.1.3 piexif==1.1.3
einops==0.4.1
jsonmerge==1.8.0
clean-fid==0.1.29
git+https://github.com/openai/CLIP@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
resize-right==0.0.2
torchdiffeq==0.2.3
kornia==0.6.7

View file

@ -87,12 +87,12 @@ axis_options = [
AxisOption("Prompt S/R", str, apply_prompt, format_value), AxisOption("Prompt S/R", str, apply_prompt, format_value),
AxisOption("Sampler", str, apply_sampler, format_value), AxisOption("Sampler", str, apply_sampler, format_value),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value), AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label), AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label),
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label), AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label),
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label), AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label), AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
AxisOption("DDIM Eta", float, apply_field("ddim_eta"), format_value_add_label), AxisOption("Eta", float, apply_field("eta"), format_value_add_label),
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label),# as it is now all AxisOptionImg2Img items must go after AxisOption ones AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
] ]

View file

@ -1,5 +1,11 @@
.output-html p {margin: 0 0.5em;} .output-html p {margin: 0 0.5em;}
.row > *,
.row > .gr-form > * {
min-width: min(120px, 100%);
flex: 1 1 0%;
}
.performance { .performance {
font-size: 0.85em; font-size: 0.85em;
color: #444; color: #444;
@ -43,13 +49,17 @@
margin-right: auto; margin-right: auto;
} }
#random_seed, #random_subseed, #reuse_seed, #reuse_subseed{ #random_seed, #random_subseed, #reuse_seed, #reuse_subseed, #open_folder{
min-width: auto; min-width: auto;
flex-grow: 0; flex-grow: 0;
padding-left: 0.25em; padding-left: 0.25em;
padding-right: 0.25em; padding-right: 0.25em;
} }
#hidden_element{
display: none;
}
#seed_row, #subseed_row{ #seed_row, #subseed_row{
gap: 0.5rem; gap: 0.5rem;
} }
@ -389,3 +399,7 @@ input[type="range"]{
border-radius: 8px; border-radius: 8px;
display: none; display: none;
} }
.red {
color: red;
}

View file

@ -1,6 +1,7 @@
import os import os
import threading import threading
from modules import devices
from modules.paths import script_path from modules.paths import script_path
import signal import signal
@ -47,6 +48,8 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func): def wrap_gradio_gpu_call(func):
def f(*args, **kwargs): def f(*args, **kwargs):
devices.torch_gc()
shared.state.sampling_step = 0 shared.state.sampling_step = 0
shared.state.job_count = -1 shared.state.job_count = -1
shared.state.job_no = 0 shared.state.job_no = 0
@ -62,6 +65,8 @@ def wrap_gradio_gpu_call(func):
shared.state.job = "" shared.state.job = ""
shared.state.job_count = 0 shared.state.job_count = 0
devices.torch_gc()
return res return res
return modules.ui.wrap_gradio_call(f) return modules.ui.wrap_gradio_call(f)