Merge branch 'master' into image_info_tab

This commit is contained in:
AUTOMATIC1111 2022-09-17 14:57:10 +03:00 committed by GitHub
commit 0d7fdb1791
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 773 additions and 95 deletions

View file

@ -51,7 +51,7 @@ Alternatively, use [Google Colab](https://colab.research.google.com/drive/1Iy-xW
1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH" 1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
2. Install [git](https://git-scm.com/download/win). 2. Install [git](https://git-scm.com/download/win).
3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`. 3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
4. Place `model.ckpt` in the base directory, alongside `webui.py`. 4. Place `model.ckpt` in the `models` directory.
5. _*(Optional)*_ Place `GFPGANv1.3.pth` in the base directory, alongside `webui.py`. 5. _*(Optional)*_ Place `GFPGANv1.3.pth` in the base directory, alongside `webui.py`.
6. Run `webui-user.bat` from Windows Explorer as normal, non-administrate, user. 6. Run `webui-user.bat` from Windows Explorer as normal, non-administrate, user.
@ -81,6 +81,7 @@ The documentation was moved from this README over to the project's [wiki](https:
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion - Ideas for optimizations - https://github.com/basujindal/stable-diffusion
- Doggettx - Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing. - Doggettx - Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd - Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
- CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator - CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You) - (You)

View file

@ -274,7 +274,7 @@ def apply_filename_pattern(x, p, seed, prompt):
x = x.replace("[height]", str(p.height)) x = x.replace("[height]", str(p.height))
x = x.replace("[sampler]", sd_samplers.samplers[p.sampler_index].name) x = x.replace("[sampler]", sd_samplers.samplers[p.sampler_index].name)
x = x.replace("[model_hash]", shared.sd_model_hash) x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
x = x.replace("[date]", datetime.date.today().isoformat()) x = x.replace("[date]", datetime.date.today().isoformat())
if cmd_opts.hide_ui_dir_config: if cmd_opts.hide_ui_dir_config:
@ -353,13 +353,12 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
}) })
if extension.lower() in ("jpg", "jpeg", "webp"): if extension.lower() in ("jpg", "jpeg", "webp"):
image.save(fullfn, quality=opts.jpeg_quality, exif_bytes=exif_bytes()) image.save(fullfn, quality=opts.jpeg_quality)
if opts.enable_pnginfo and info is not None:
piexif.insert(exif_bytes(), fullfn)
else: else:
image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo) image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo)
if extension.lower() == "webp":
piexif.insert(exif_bytes, fullfn)
target_side_length = 4000 target_side_length = 4000
oversize = image.width > target_side_length or image.height > target_side_length oversize = image.width > target_side_length or image.height > target_side_length
if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024): if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024):
@ -370,7 +369,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
elif oversize: elif oversize:
image = image.resize((image.width * target_side_length // image.height, target_side_length), LANCZOS) image = image.resize((image.width * target_side_length // image.height, target_side_length), LANCZOS)
image.save(fullfn_without_extension + ".jpg", quality=opts.jpeg_quality, exif_bytes=exif_bytes()) image.save(fullfn_without_extension + ".jpg", quality=opts.jpeg_quality)
if opts.enable_pnginfo and info is not None:
piexif.insert(exif_bytes(), fullfn)
if opts.save_txt and info is not None: if opts.save_txt and info is not None:
with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file: with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file:

77
modules/memmon.py Normal file
View file

@ -0,0 +1,77 @@
import threading
import time
from collections import defaultdict
import torch
class MemUsageMonitor(threading.Thread):
run_flag = None
device = None
disabled = False
opts = None
data = None
def __init__(self, name, device, opts):
threading.Thread.__init__(self)
self.name = name
self.device = device
self.opts = opts
self.daemon = True
self.run_flag = threading.Event()
self.data = defaultdict(int)
def run(self):
if self.disabled:
return
while True:
self.run_flag.wait()
torch.cuda.reset_peak_memory_stats()
self.data.clear()
if self.opts.memmon_poll_rate <= 0:
self.run_flag.clear()
continue
self.data["min_free"] = torch.cuda.mem_get_info()[0]
while self.run_flag.is_set():
free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug?
self.data["min_free"] = min(self.data["min_free"], free)
time.sleep(1 / self.opts.memmon_poll_rate)
def dump_debug(self):
print(self, 'recorded data:')
for k, v in self.read().items():
print(k, -(v // -(1024 ** 2)))
print(self, 'raw torch memory stats:')
tm = torch.cuda.memory_stats(self.device)
for k, v in tm.items():
if 'bytes' not in k:
continue
print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))
print(torch.cuda.memory_summary())
def monitor(self):
self.run_flag.set()
def read(self):
free, total = torch.cuda.mem_get_info()
self.data["total"] = total
torch_stats = torch.cuda.memory_stats(self.device)
self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
self.data["system_peak"] = total - self.data["min_free"]
return self.data
def stop(self):
self.run_flag.clear()
return self.read()

View file

@ -188,7 +188,11 @@ def fix_seed(p):
def process_images(p: StableDiffusionProcessing) -> Processed: def process_images(p: StableDiffusionProcessing) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
if type(p.prompt) == list:
assert(len(p.prompt) > 0)
else:
assert p.prompt is not None assert p.prompt is not None
devices.torch_gc() devices.torch_gc()
fix_seed(p) fix_seed(p)
@ -227,7 +231,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
"Seed": all_seeds[index], "Seed": all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Size": f"{p.width}x{p.height}", "Size": f"{p.width}x{p.height}",
"Model hash": (None if not opts.add_model_hash_to_info or not shared.sd_model_hash else shared.sd_model_hash), "Model hash": (None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch), "Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
@ -265,6 +269,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size] seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
if (len(prompts) == 0):
break
#uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
#c = p.sd_model.get_learned_conditioning(prompts) #c = p.sd_model.get_learned_conditioning(prompts)
uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps) uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)

148
modules/sd_models.py Normal file
View file

@ -0,0 +1,148 @@
import glob
import os.path
import sys
from collections import namedtuple
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from modules import shared
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash'])
checkpoints_list = {}
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
from transformers import logging
logging.set_verbosity_error()
except Exception:
pass
def list_models():
checkpoints_list.clear()
model_dir = os.path.abspath(shared.cmd_opts.ckpt_dir)
def modeltitle(path, h):
abspath = os.path.abspath(path)
if abspath.startswith(model_dir):
name = abspath.replace(model_dir, '')
else:
name = os.path.basename(path)
if name.startswith("\\") or name.startswith("/"):
name = name[1:]
return f'{name} [{h}]'
cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
title = modeltitle(cmd_ckpt, h)
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h)
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found: {cmd_ckpt}", file=sys.stderr)
if os.path.exists(model_dir):
for filename in glob.glob(model_dir + '/**/*.ckpt', recursive=True):
h = model_hash(filename)
title = modeltitle(filename, h)
checkpoints_list[title] = CheckpointInfo(filename, title, h)
def model_hash(filename):
try:
with open(filename, "rb") as file:
import hashlib
m = hashlib.sha256()
file.seek(0x100000)
m.update(file.read(0x10000))
return m.hexdigest()[0:8]
except FileNotFoundError:
return 'NOFILE'
def select_checkpoint():
model_checkpoint = shared.opts.sd_model_checkpoint
checkpoint_info = checkpoints_list.get(model_checkpoint, None)
if checkpoint_info is not None:
return checkpoint_info
if len(checkpoints_list) == 0:
print(f"Checkpoint {model_checkpoint} not found and no other checkpoints found", file=sys.stderr)
return None
checkpoint_info = next(iter(checkpoints_list.values()))
if model_checkpoint is not None:
print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)
return checkpoint_info
def load_model_weights(model, checkpoint_file, sd_model_hash):
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
pl_sd = torch.load(checkpoint_file, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
model.load_state_dict(sd, strict=False)
if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
if not shared.cmd_opts.no_half:
model.half()
model.sd_model_hash = sd_model_hash
model.sd_model_checkpint = checkpoint_file
def load_model():
from modules import lowvram, sd_hijack
checkpoint_info = select_checkpoint()
sd_config = OmegaConf.load(shared.cmd_opts.config)
sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
else:
sd_model.to(shared.device)
sd_hijack.model_hijack.hijack(sd_model)
sd_model.eval()
print(f"Model loaded.")
return sd_model
def reload_model_weights(sd_model, info=None):
from modules import lowvram, devices
checkpoint_info = info or select_checkpoint()
if sd_model.sd_model_checkpint == checkpoint_info.filename:
return
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:
sd_model.to(devices.cpu)
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
print(f"Weights loaded.")
return sd_model

View file

@ -12,14 +12,16 @@ from modules.paths import script_path, sd_path
from modules.devices import get_optimal_device from modules.devices import get_optimal_device
import modules.styles import modules.styles
import modules.interrogate import modules.interrogate
import modules.memmon
import modules.sd_models
sd_model_file = os.path.join(script_path, 'model.ckpt') sd_model_file = os.path.join(script_path, 'model.ckpt')
if not os.path.exists(sd_model_file): default_sd_model_file = sd_model_file
sd_model_file = "models/ldm/stable-diffusion-v1/model.ckpt"
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=os.path.join(sd_path, sd_model_file), help="path to checkpoint of model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
parser.add_argument("--ckpt-dir", type=str, default=os.path.join(script_path, 'models'), help="path to directory with stable diffusion checkpoints",)
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default='GFPGANv1.3.pth') parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default='GFPGANv1.3.pth')
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
@ -87,13 +89,17 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = [] face_restorers = []
modules.sd_models.list_models()
class Options: class Options:
class OptionInfo: class OptionInfo:
def __init__(self, default=None, label="", component=None, component_args=None): def __init__(self, default=None, label="", component=None, component_args=None, onchange=None):
self.default = default self.default = default
self.label = label self.label = label
self.component = component self.component = component
self.component_args = component_args self.component_args = component_args
self.onchange = onchange
data = None data = None
hide_dirs = {"visible": False} if cmd_opts.hide_ui_dir_config else None hide_dirs = {"visible": False} if cmd_opts.hide_ui_dir_config else None
@ -138,6 +144,7 @@ class Options:
"show_progressbar": OptionInfo(True, "Show progressbar"), "show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), "show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."), "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."),
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step":1}),
"face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}), "face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."), "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
@ -148,6 +155,7 @@ class Options:
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
"interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"), "interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": [x.title for x in modules.sd_models.checkpoints_list.values()]}),
} }
def __init__(self): def __init__(self):
@ -178,6 +186,10 @@ class Options:
with open(filename, "r", encoding="utf8") as file: with open(filename, "r", encoding="utf8") as file:
self.data = json.load(file) self.data = json.load(file)
def onchange(self, key, func):
item = self.data_labels.get(key)
item.onchange = func
opts = Options() opts = Options()
if os.path.exists(config_filename): if os.path.exists(config_filename):
@ -186,7 +198,6 @@ if os.path.exists(config_filename):
sd_upscalers = [] sd_upscalers = []
sd_model = None sd_model = None
sd_model_hash = ''
progress_print_out = sys.stdout progress_print_out = sys.stdout
@ -217,3 +228,6 @@ class TotalTQDM:
total_tqdm = TotalTQDM() total_tqdm = TotalTQDM()
mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
mem_mon.start()

View file

@ -119,6 +119,7 @@ def save_files(js_data, images, index):
def wrap_gradio_call(func): def wrap_gradio_call(func):
def f(*args, **kwargs): def f(*args, **kwargs):
shared.mem_mon.monitor()
t = time.perf_counter() t = time.perf_counter()
try: try:
@ -135,8 +136,20 @@ def wrap_gradio_call(func):
elapsed = time.perf_counter() - t elapsed = time.perf_counter() - t
mem_stats = {k: -(v//-(1024*1024)) for k,v in shared.mem_mon.stop().items()}
active_peak = mem_stats['active_peak']
reserved_peak = mem_stats['reserved_peak']
sys_peak = '?' if opts.memmon_poll_rate <= 0 else mem_stats['system_peak']
sys_total = mem_stats['total']
sys_pct = '?' if opts.memmon_poll_rate <= 0 else round(sys_peak/sys_total * 100, 2)
vram_tooltip = "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.&#013;" \
"Torch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.&#013;" \
"Sys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%)."
vram_html = '' if opts.memmon_poll_rate == 0 else f"<p class='vram' title='{vram_tooltip}'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
# last item is always HTML # last item is always HTML
res[-1] = res[-1] + f"<p class='performance'>Time taken: {elapsed:.2f}s</p>" res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>"
shared.state.interrupted = False shared.state.interrupted = False
@ -324,6 +337,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False) custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
progressbar = gr.HTML(elem_id="progressbar")
with gr.Group(): with gr.Group():
txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
txt2img_gallery = gr.Gallery(label='Output', elem_id='txt2img_gallery').style(grid=4) txt2img_gallery = gr.Gallery(label='Output', elem_id='txt2img_gallery').style(grid=4)
@ -336,8 +351,6 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
send_to_extras = gr.Button('Send to extras') send_to_extras = gr.Button('Send to extras')
interrupt = gr.Button('Interrupt') interrupt = gr.Button('Interrupt')
progressbar = gr.HTML(elem_id="progressbar")
with gr.Group(): with gr.Group():
html_info = gr.HTML() html_info = gr.HTML()
generation_info = gr.Textbox(visible=False) generation_info = gr.Textbox(visible=False)
@ -461,6 +474,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True) custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
progressbar = gr.HTML(elem_id="progressbar")
with gr.Group(): with gr.Group():
img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
img2img_gallery = gr.Gallery(label='Output', elem_id='img2img_gallery').style(grid=4) img2img_gallery = gr.Gallery(label='Output', elem_id='img2img_gallery').style(grid=4)
@ -474,7 +489,6 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
interrupt = gr.Button('Interrupt') interrupt = gr.Button('Interrupt')
img2img_save_style = gr.Button('Save prompt as style') img2img_save_style = gr.Button('Save prompt as style')
progressbar = gr.HTML(elem_id="progressbar")
with gr.Group(): with gr.Group():
html_info = gr.HTML() html_info = gr.HTML()
@ -745,7 +759,12 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
continue continue
oldval = opts.data.get(key, None)
opts.data[key] = value opts.data[key] = value
if oldval != value and opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange()
up.append(comp.update(value=value)) up.append(comp.update(value=value))
opts.save(shared.config_filename) opts.save(shared.config_filename)

View file

@ -66,6 +66,8 @@ titles = {
"Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both", "Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both",
"Apply style": "Insert selected styles into prompt fields", "Apply style": "Insert selected styles into prompt fields",
"Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style use that as placeholder for your prompt when you use the style in the future.", "Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style use that as placeholder for your prompt when you use the style in the future.",
"Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.",
} }
function gradioApp(){ function gradioApp(){
@ -74,6 +76,41 @@ function gradioApp(){
global_progressbar = null global_progressbar = null
function closeModal() {
gradioApp().getElementById("lightboxModal").style.display = "none";
}
function showModal(elem) {
gradioApp().getElementById("modalImage").src = elem.src
gradioApp().getElementById("lightboxModal").style.display = "block";
}
function showGalleryImage(){
setTimeout(function() {
fullImg_preview = gradioApp().querySelectorAll('img.w-full.object-contain')
if(fullImg_preview != null){
fullImg_preview.forEach(function function_name(e) {
if(e && e.parentElement.tagName == 'DIV'){
e.style.cursor='pointer'
elemfunc = function(elem){
elem.onclick = function(){showModal(elem)};
}
elemfunc(e)
}
});
}
}, 100);
}
function galleryImageHandler(e){
if(e && e.parentElement.tagName == 'BUTTON'){
e.onclick = showGalleryImage;
}
}
function addTitles(root){ function addTitles(root){
root.querySelectorAll('span, button, select').forEach(function(span){ root.querySelectorAll('span, button, select').forEach(function(span){
tooltip = titles[span.textContent]; tooltip = titles[span.textContent];
@ -116,12 +153,17 @@ function addTitles(root){
img2img_preview.style.height = img2img_gallery.clientHeight + "px" img2img_preview.style.height = img2img_gallery.clientHeight + "px"
} }
window.setTimeout(requestProgress, 500) window.setTimeout(requestProgress, 500)
}); });
mutationObserver.observe( progressbar, { childList:true, subtree:true }) mutationObserver.observe( progressbar, { childList:true, subtree:true })
} }
fullImg_preview = gradioApp().querySelectorAll('img.w-full')
if(fullImg_preview != null){
fullImg_preview.forEach(galleryImageHandler);
}
} }
document.addEventListener("DOMContentLoaded", function() { document.addEventListener("DOMContentLoaded", function() {
@ -129,6 +171,27 @@ document.addEventListener("DOMContentLoaded", function() {
addTitles(gradioApp()); addTitles(gradioApp());
}); });
mutationObserver.observe( gradioApp(), { childList:true, subtree:true }) mutationObserver.observe( gradioApp(), { childList:true, subtree:true })
const modalFragment = document.createDocumentFragment();
const modal = document.createElement('div')
modal.onclick = closeModal;
const modalClose = document.createElement('span')
modalClose.className = 'modalClose cursor';
modalClose.innerHTML = '&times;'
modalClose.onclick = closeModal;
modal.id = "lightboxModal";
modal.appendChild(modalClose)
const modalImage = document.createElement('img')
modalImage.id = 'modalImage';
modalImage.onclick = closeModal;
modal.appendChild(modalImage)
gradioApp().getRootNode().appendChild(modal)
document.body.appendChild(modalFragment);
}); });
function selected_gallery_index(){ function selected_gallery_index(){
@ -180,6 +243,15 @@ function submit(){
for(var i=0;i<arguments.length;i++){ for(var i=0;i<arguments.length;i++){
res.push(arguments[i]) res.push(arguments[i])
} }
// As it is currently, txt2img and img2img send back the previous output args (txt2img_gallery, generation_info, html_info) whenever you generate a new image.
// This can lead to uploading a huge gallery of previously generated images, which leads to an unnecessary delay between submitting and beginning to generate.
// I don't know why gradio is seding outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some.
// If gradio at some point stops sending outputs, this may break something
if(Array.isArray(res[res.length - 3])){
res[res.length - 3] = null
}
return res return res
} }

View file

@ -59,7 +59,7 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
return x / x.std() return x / x.std()
Cached = namedtuple("Cached", ["noise", "cfg_scale", "steps", "latent", "original_prompt"]) Cached = namedtuple("Cached", ["noise", "cfg_scale", "steps", "latent", "original_prompt", "original_negative_prompt"])
class Script(scripts.Script): class Script(scripts.Script):
@ -74,19 +74,20 @@ class Script(scripts.Script):
def ui(self, is_img2img): def ui(self, is_img2img):
original_prompt = gr.Textbox(label="Original prompt", lines=1) original_prompt = gr.Textbox(label="Original prompt", lines=1)
original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1)
cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0) cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0)
st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50) st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50)
randomness = gr.Slider(label="randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0) randomness = gr.Slider(label="randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
return [original_prompt, cfg, st, randomness] return [original_prompt, original_negative_prompt, cfg, st, randomness]
def run(self, p, original_prompt, cfg, st, randomness): def run(self, p, original_prompt, original_negative_prompt, cfg, st, randomness):
p.batch_size = 1 p.batch_size = 1
p.batch_count = 1 p.batch_count = 1
def sample_extra(x, conditioning, unconditional_conditioning): def sample_extra(x, conditioning, unconditional_conditioning):
lat = (p.init_latent.cpu().numpy() * 10).astype(int) lat = (p.init_latent.cpu().numpy() * 10).astype(int)
same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st and self.cache.original_prompt == original_prompt same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st and self.cache.original_prompt == original_prompt and self.cache.original_negative_prompt == original_negative_prompt
same_everything = same_params and self.cache.latent.shape == lat.shape and np.abs(self.cache.latent-lat).sum() < 100 same_everything = same_params and self.cache.latent.shape == lat.shape and np.abs(self.cache.latent-lat).sum() < 100
if same_everything: if same_everything:
@ -94,9 +95,9 @@ class Script(scripts.Script):
else: else:
shared.state.job_count += 1 shared.state.job_count += 1
cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt]) cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt])
uncond = p.sd_model.get_learned_conditioning(p.batch_size * [""]) uncond = p.sd_model.get_learned_conditioning(p.batch_size * [original_negative_prompt])
rec_noise = find_noise_for_image(p, cond, uncond, cfg, st) rec_noise = find_noise_for_image(p, cond, uncond, cfg, st)
self.cache = Cached(rec_noise, cfg, st, lat, original_prompt) self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt)
rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], [p.seed + x + 1 for x in range(p.init_latent.shape[0])]) rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], [p.seed + x + 1 for x in range(p.init_latent.shape[0])])

290
scripts/outpainting_mk_2.py Normal file
View file

@ -0,0 +1,290 @@
import math
import numpy as np
import skimage
import modules.scripts as scripts
import gradio as gr
from PIL import Image, ImageDraw
from modules import images, processing, devices
from modules.processing import Processed, process_images
from modules.shared import opts, cmd_opts, state
def expand(x, dir, amount, power=0.75):
is_left = dir == 3
is_right = dir == 1
is_up = dir == 0
is_down = dir == 2
if is_left or is_right:
noise = np.zeros((x.shape[0], amount, 3), dtype=float)
indexes = np.random.random((x.shape[0], amount)) ** power * (1 - np.arange(amount) / amount)
if is_right:
indexes = 1 - indexes
indexes = (indexes * (x.shape[1] - 1)).astype(int)
for row in range(x.shape[0]):
if is_left:
noise[row] = x[row][indexes[row]]
else:
noise[row] = np.flip(x[row][indexes[row]], axis=0)
x = np.concatenate([noise, x] if is_left else [x, noise], axis=1)
return x
if is_up or is_down:
noise = np.zeros((amount, x.shape[1], 3), dtype=float)
indexes = np.random.random((x.shape[1], amount)) ** power * (1 - np.arange(amount) / amount)
if is_down:
indexes = 1 - indexes
indexes = (indexes * x.shape[0] - 1).astype(int)
for row in range(x.shape[1]):
if is_up:
noise[:, row] = x[:, row][indexes[row]]
else:
noise[:, row] = np.flip(x[:, row][indexes[row]], axis=0)
x = np.concatenate([noise, x] if is_up else [x, noise], axis=0)
return x
def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05):
# helper fft routines that keep ortho normalization and auto-shift before and after fft
def _fft2(data):
if data.ndim > 2: # has channels
out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
for c in range(data.shape[2]):
c_data = data[:, :, c]
out_fft[:, :, c] = np.fft.fft2(np.fft.fftshift(c_data), norm="ortho")
out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
else: # one channel
out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho")
out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
return out_fft
def _ifft2(data):
if data.ndim > 2: # has channels
out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
for c in range(data.shape[2]):
c_data = data[:, :, c]
out_ifft[:, :, c] = np.fft.ifft2(np.fft.fftshift(c_data), norm="ortho")
out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
else: # one channel
out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho")
out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
return out_ifft
def _get_gaussian_window(width, height, std=3.14, mode=0):
window_scale_x = float(width / min(width, height))
window_scale_y = float(height / min(width, height))
window = np.zeros((width, height))
x = (np.arange(width) / width * 2. - 1.) * window_scale_x
for y in range(height):
fy = (y / height * 2. - 1.) * window_scale_y
if mode == 0:
window[:, y] = np.exp(-(x ** 2 + fy ** 2) * std)
else:
window[:, y] = (1 / ((x ** 2 + 1.) * (fy ** 2 + 1.))) ** (std / 3.14) # hey wait a minute that's not gaussian
return window
def _get_masked_window_rgb(np_mask_grey, hardness=1.):
np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3))
if hardness != 1.:
hardened = np_mask_grey[:] ** hardness
else:
hardened = np_mask_grey[:]
for c in range(3):
np_mask_rgb[:, :, c] = hardened[:]
return np_mask_rgb
width = _np_src_image.shape[0]
height = _np_src_image.shape[1]
num_channels = _np_src_image.shape[2]
np_src_image = _np_src_image[:] * (1. - np_mask_rgb)
np_mask_grey = (np.sum(np_mask_rgb, axis=2) / 3.)
img_mask = np_mask_grey > 1e-6
ref_mask = np_mask_grey < 1e-3
windowed_image = _np_src_image * (1. - _get_masked_window_rgb(np_mask_grey))
windowed_image /= np.max(windowed_image)
windowed_image += np.average(_np_src_image) * np_mask_rgb # / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black, we get better results from fft by filling the average unmasked color
src_fft = _fft2(windowed_image) # get feature statistics from masked src img
src_dist = np.absolute(src_fft)
src_phase = src_fft / src_dist
noise_window = _get_gaussian_window(width, height, mode=1) # start with simple gaussian noise
noise_rgb = np.random.random_sample((width, height, num_channels))
noise_grey = (np.sum(noise_rgb, axis=2) / 3.)
noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter
for c in range(num_channels):
noise_rgb[:, :, c] += (1. - color_variation) * noise_grey
noise_fft = _fft2(noise_rgb)
for c in range(num_channels):
noise_fft[:, :, c] *= noise_window
noise_rgb = np.real(_ifft2(noise_fft))
shaped_noise_fft = _fft2(noise_rgb)
shaped_noise_fft[:, :, :] = np.absolute(shaped_noise_fft[:, :, :]) ** 2 * (src_dist ** noise_q) * src_phase # perform the actual shaping
brightness_variation = 0. # color_variation # todo: temporarily tieing brightness variation to color variation for now
contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2.
# scikit-image is used for histogram matching, very convenient!
shaped_noise = np.real(_ifft2(shaped_noise_fft))
shaped_noise -= np.min(shaped_noise)
shaped_noise /= np.max(shaped_noise)
shaped_noise[img_mask, :] = skimage.exposure.match_histograms(shaped_noise[img_mask, :] ** 1., contrast_adjusted_np_src[ref_mask, :], channel_axis=1)
shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb
matched_noise = shaped_noise[:]
return np.clip(matched_noise, 0., 1.)
class Script(scripts.Script):
def title(self):
return "Outpainting mk2"
def show(self, is_img2img):
return is_img2img
def ui(self, is_img2img):
if not is_img2img:
return None
info = gr.HTML("<p style=\"margin-bottom:0.75em\">Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8</p>")
pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128)
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, visible=False)
direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'])
noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0)
color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05)
return [info, pixels, mask_blur, direction, noise_q, color_variation]
def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation):
initial_seed_and_info = [None, None]
process_width = p.width
process_height = p.height
p.mask_blur = mask_blur*4
p.inpaint_full_res = False
p.inpainting_fill = 1
p.do_not_save_samples = True
p.do_not_save_grid = True
left = pixels if "left" in direction else 0
right = pixels if "right" in direction else 0
up = pixels if "up" in direction else 0
down = pixels if "down" in direction else 0
init_img = p.init_images[0]
target_w = math.ceil((init_img.width + left + right) / 64) * 64
target_h = math.ceil((init_img.height + up + down) / 64) * 64
if left > 0:
left = left * (target_w - init_img.width) // (left + right)
if right > 0:
right = target_w - init_img.width - left
if up > 0:
up = up * (target_h - init_img.height) // (up + down)
if down > 0:
down = target_h - init_img.height - up
init_image = p.init_images[0]
state.job_count = (1 if left > 0 else 0) + (1 if right > 0 else 0)+ (1 if up > 0 else 0)+ (1 if down > 0 else 0)
def expand(init, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
is_horiz = is_left or is_right
is_vert = is_top or is_bottom
pixels_horiz = expand_pixels if is_horiz else 0
pixels_vert = expand_pixels if is_vert else 0
img = Image.new("RGB", (init.width + pixels_horiz, init.height + pixels_vert))
img.paste(init, (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
mask = Image.new("RGB", (init.width + pixels_horiz, init.height + pixels_vert), "white")
draw = ImageDraw.Draw(mask)
draw.rectangle((
expand_pixels + mask_blur if is_left else 0,
expand_pixels + mask_blur if is_top else 0,
mask.width - expand_pixels - mask_blur if is_right else mask.width,
mask.height - expand_pixels - mask_blur if is_bottom else mask.height,
), fill="black")
np_image = (np.asarray(img) / 255.0).astype(np.float64)
np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
out = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
target_width = min(process_width, init.width + pixels_horiz) if is_horiz else img.width
target_height = min(process_height, init.height + pixels_vert) if is_vert else img.height
crop_region = (
0 if is_left else out.width - target_width,
0 if is_top else out.height - target_height,
target_width if is_left else out.width,
target_height if is_top else out.height,
)
image_to_process = out.crop(crop_region)
mask = mask.crop(crop_region)
p.width = target_width if is_horiz else img.width
p.height = target_height if is_vert else img.height
p.init_images = [image_to_process]
p.image_mask = mask
latent_mask = Image.new("RGB", (p.width, p.height), "white")
draw = ImageDraw.Draw(latent_mask)
draw.rectangle((
expand_pixels + mask_blur * 2 if is_left else 0,
expand_pixels + mask_blur * 2 if is_top else 0,
mask.width - expand_pixels - mask_blur * 2 if is_right else mask.width,
mask.height - expand_pixels - mask_blur * 2 if is_bottom else mask.height,
), fill="black")
p.latent_mask = latent_mask
proc = process_images(p)
proc_img = proc.images[0]
if initial_seed_and_info[0] is None:
initial_seed_and_info[0] = proc.seed
initial_seed_and_info[1] = proc.info
out.paste(proc_img, (0 if is_left else out.width - proc_img.width, 0 if is_top else out.height - proc_img.height))
return out
img = init_image
if left > 0:
img = expand(img, left, is_left=True)
if right > 0:
img = expand(img, right, is_right=True)
if up > 0:
img = expand(img, up, is_top=True)
if down > 0:
img = expand(img, down, is_bottom=True)
res = Processed(p, [img], initial_seed_and_info[0], initial_seed_and_info[1])
if opts.samples_save:
images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
return res

View file

@ -13,28 +13,42 @@ from modules.shared import opts, cmd_opts, state
class Script(scripts.Script): class Script(scripts.Script):
def title(self): def title(self):
return "Prompts from file" return "Prompts from file or textbox"
def ui(self, is_img2img): def ui(self, is_img2img):
# This checkbox would look nicer as two tabs, but there are two problems:
# 1) There is a bug in Gradio 3.3 that prevents visibility from working on Tabs
# 2) Even with Gradio 3.3.1, returning a control (like Tabs) that can't be used as input
# causes a AttributeError: 'Tabs' object has no attribute 'preprocess' assert,
# due to the way Script assumes all controls returned can be used as inputs.
# Therefore, there's no good way to use grouping components right now,
# so we will use a checkbox! :)
checkbox_txt = gr.Checkbox(label="Show Textbox", value=False)
file = gr.File(label="File with inputs", type='bytes') file = gr.File(label="File with inputs", type='bytes')
prompt_txt = gr.TextArea(label="Prompts")
checkbox_txt.change(fn=lambda x: [gr.File.update(visible = not x), gr.TextArea.update(visible = x)], inputs=[checkbox_txt], outputs=[file, prompt_txt])
return [checkbox_txt, file, prompt_txt]
return [file] def run(self, p, checkbox_txt, data: bytes, prompt_txt: str):
if (checkbox_txt):
def run(self, p, data: bytes): lines = [x.strip() for x in prompt_txt.splitlines()]
else:
lines = [x.strip() for x in data.decode('utf8', errors='ignore').split("\n")] lines = [x.strip() for x in data.decode('utf8', errors='ignore').split("\n")]
lines = [x for x in lines if len(x) > 0] lines = [x for x in lines if len(x) > 0]
batch_count = math.ceil(len(lines) / p.batch_size) img_count = len(lines) * p.n_iter
print(f"Will process {len(lines) * p.n_iter} images in {batch_count * p.n_iter} batches.") batch_count = math.ceil(img_count / p.batch_size)
loop_count = math.ceil(batch_count / p.n_iter)
print(f"Will process {img_count} images in {batch_count} batches.")
p.do_not_save_grid = True p.do_not_save_grid = True
state.job_count = batch_count state.job_count = batch_count
images = [] images = []
for batch_no in range(batch_count): for loop_no in range(loop_count):
state.job = f"{batch_no + 1} out of {batch_count * p.n_iter}" state.job = f"{loop_no + 1} out of {loop_count}"
p.prompt = lines[batch_no*p.batch_size:(batch_no+1)*p.batch_size] * p.n_iter p.prompt = lines[loop_no*p.batch_size:(loop_no+1)*p.batch_size] * p.n_iter
proc = process_images(p) proc = process_images(p)
images += proc.images images += proc.images

View file

@ -10,7 +10,9 @@ import gradio as gr
from modules import images from modules import images
from modules.processing import process_images, Processed from modules.processing import process_images, Processed
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.sd_samplers import modules.sd_samplers
import modules.sd_models
import re import re
@ -41,6 +43,15 @@ def apply_sampler(p, x, xs):
p.sampler_index = sampler_index p.sampler_index = sampler_index
def apply_checkpoint(p, x, xs):
applicable = [info for info in modules.sd_models.checkpoints_list.values() if x in info.title]
assert len(applicable) > 0, f'Checkpoint {x} for found'
info = applicable[0]
modules.sd_models.reload_model_weights(shared.sd_model, info)
def format_value_add_label(p, opt, x): def format_value_add_label(p, opt, x):
if type(x) == float: if type(x) == float:
x = round(x, 8) x = round(x, 8)
@ -74,15 +85,16 @@ axis_options = [
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label), AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
AxisOption("Prompt S/R", str, apply_prompt, format_value), AxisOption("Prompt S/R", str, apply_prompt, format_value),
AxisOption("Sampler", str, apply_sampler, format_value), AxisOption("Sampler", str, apply_sampler, format_value),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
] ]
def draw_xy_grid(p, xs, ys, x_label, y_label, cell, draw_legend): def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend):
res = [] res = []
ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys] ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs] hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
first_pocessed = None first_pocessed = None
@ -206,8 +218,8 @@ class Script(scripts.Script):
p, p,
xs=xs, xs=xs,
ys=ys, ys=ys,
x_label=lambda x: x_opt.format_value(p, x_opt, x), x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
y_label=lambda y: y_opt.format_value(p, y_opt, y), y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
cell=cell, cell=cell,
draw_legend=draw_legend draw_legend=draw_legend
) )
@ -215,4 +227,7 @@ class Script(scripts.Script):
if opts.grid_save: if opts.grid_save:
images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p) images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p)
# restore checkpoint in case it was changed by axes
modules.sd_models.reload_model_weights(shared.sd_model)
return processed return processed

View file

@ -1,5 +1,21 @@
.output-html p {margin: 0 0.5em;} .output-html p {margin: 0 0.5em;}
.performance { font-size: 0.85em; color: #444; }
.performance {
font-size: 0.85em;
color: #444;
display: flex;
justify-content: space-between;
white-space: nowrap;
}
.performance .time {
margin-right: 0;
}
.performance .vram {
margin-left: 0;
text-align: right;
}
#generate{ #generate{
min-height: 4.5em; min-height: 4.5em;
@ -151,6 +167,12 @@ input[type="range"]{
#txt2img_negative_prompt, #img2img_negative_prompt{ #txt2img_negative_prompt, #img2img_negative_prompt{
} }
#progressbar{
position: absolute;
z-index: 1000;
right: 0;
}
.progressDiv{ .progressDiv{
width: 100%; width: 100%;
height: 30px; height: 30px;
@ -174,3 +196,40 @@ input[type="range"]{
border-radius: 8px; border-radius: 8px;
} }
#lightboxModal{
display: none;
position: fixed;
z-index: 900;
padding-top: 100px;
left: 0;
top: 0;
width: 100%;
height: 100%;
overflow: auto;
background-color: rgba(20, 20, 20, 0.95);
}
.modalClose {
color: white;
position: absolute;
top: 10px;
right: 25px;
font-size: 35px;
font-weight: bold;
}
.modalClose:hover,
.modalClose:focus {
color: #999;
text-decoration: none;
cursor: pointer;
}
#modalImage {
display: block;
margin-left: auto;
margin-right: auto;
margin-top: auto;
width: auto;
}

View file

@ -3,12 +3,8 @@ import threading
from modules.paths import script_path from modules.paths import script_path
import torch
from omegaconf import OmegaConf
import signal import signal
from ldm.util import instantiate_from_config
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
import modules.ui import modules.ui
@ -23,6 +19,7 @@ import modules.extras
import modules.lowvram import modules.lowvram
import modules.txt2img import modules.txt2img
import modules.img2img import modules.img2img
import modules.sd_models
modules.codeformer_model.setup_codeformer() modules.codeformer_model.setup_codeformer()
@ -32,31 +29,19 @@ shared.face_restorers.append(modules.face_restoration.FaceRestoration())
esrgan.load_models(cmd_opts.esrgan_models_path) esrgan.load_models(cmd_opts.esrgan_models_path)
realesrgan.setup_realesrgan() realesrgan.setup_realesrgan()
def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model [{shared.sd_model_hash}] from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
if cmd_opts.opt_channelslast:
model = model.to(memory_format=torch.channels_last)
model.eval()
return model
queue_lock = threading.Lock() queue_lock = threading.Lock()
def wrap_queued_call(func):
def f(*args, **kwargs):
with queue_lock:
res = func(*args, **kwargs)
return res
return f
def wrap_gradio_gpu_call(func): def wrap_gradio_gpu_call(func):
def f(*args, **kwargs): def f(*args, **kwargs):
shared.state.sampling_step = 0 shared.state.sampling_step = 0
@ -79,33 +64,8 @@ def wrap_gradio_gpu_call(func):
modules.scripts.load_scripts(os.path.join(script_path, "scripts")) modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
try: shared.sd_model = modules.sd_models.load_model()
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
from transformers import logging
logging.set_verbosity_error()
except Exception:
pass
with open(cmd_opts.ckpt, "rb") as file:
import hashlib
m = hashlib.sha256()
file.seek(0x100000)
m.update(file.read(0x10000))
shared.sd_model_hash = m.hexdigest()[0:8]
sd_config = OmegaConf.load(cmd_opts.config)
shared.sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
shared.sd_model = (shared.sd_model if cmd_opts.no_half else shared.sd_model.half())
if cmd_opts.lowvram or cmd_opts.medvram:
modules.lowvram.setup_for_low_vram(shared.sd_model, cmd_opts.medvram)
else:
shared.sd_model = shared.sd_model.to(shared.device)
modules.sd_hijack.model_hijack.hijack(shared.sd_model)
def webui(): def webui():