Merge branch 'master' into fix-vram

This commit is contained in:
Jairo Correa 2022-10-04 19:53:52 -03:00
commit 1f50971fb8
22 changed files with 321 additions and 173 deletions

View file

@ -15,7 +15,7 @@ titles = {
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed", "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
"\u{1f3a8}": "Add a random artist to the prompt.", "\u{1f3a8}": "Add a random artist to the prompt.",
"\u2199\ufe0f": "Read generation parameters from prompt into user interface.", "\u2199\ufe0f": "Read generation parameters from prompt into user interface.",
"\uD83D\uDCC2": "Open images output directory", "\u{1f4c2}": "Open images output directory",
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
@ -47,6 +47,7 @@ titles = {
"Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work", "Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work",
"Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others", "Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others",
"Prompt order": "Separate a list of words with commas, and the script will make a variation of prompt with those words for their every possible order",
"Tiling": "Produce an image that can be tiled.", "Tiling": "Produce an image that can be tiled.",
"Tile overlap": "For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.", "Tile overlap": "For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.",

View file

@ -8,7 +8,7 @@ import torch
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
import modules.upscaler import modules.upscaler
from modules import shared, modelloader from modules import devices, modelloader
from modules.bsrgan_model_arch import RRDBNet from modules.bsrgan_model_arch import RRDBNet
from modules.paths import models_path from modules.paths import models_path
@ -44,13 +44,13 @@ class UpscalerBSRGAN(modules.upscaler.Upscaler):
model = self.load_model(selected_file) model = self.load_model(selected_file)
if model is None: if model is None:
return img return img
model.to(shared.device) model.to(devices.device_bsrgan)
torch.cuda.empty_cache() torch.cuda.empty_cache()
img = np.array(img) img = np.array(img)
img = img[:, :, ::-1] img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255 img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(shared.device) img = img.unsqueeze(0).to(devices.device_bsrgan)
with torch.no_grad(): with torch.no_grad():
output = model(img) output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy() output = output.squeeze().float().cpu().clamp_(0, 1).numpy()

View file

@ -69,10 +69,14 @@ def setup_model(dirname):
self.net = net self.net = net
self.face_helper = face_helper self.face_helper = face_helper
self.net.to(devices.device_codeformer)
return net, face_helper return net, face_helper
def send_model_to(self, device):
self.net.to(device)
self.face_helper.face_det.to(device)
self.face_helper.face_parse.to(device)
def restore(self, np_image, w=None): def restore(self, np_image, w=None):
np_image = np_image[:, :, ::-1] np_image = np_image[:, :, ::-1]
@ -82,6 +86,8 @@ def setup_model(dirname):
if self.net is None or self.face_helper is None: if self.net is None or self.face_helper is None:
return np_image return np_image
self.send_model_to(devices.device_codeformer)
self.face_helper.clean_all() self.face_helper.clean_all()
self.face_helper.read_image(np_image) self.face_helper.read_image(np_image)
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
@ -97,7 +103,7 @@ def setup_model(dirname):
output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0] output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output del output
devices.torch_gc() torch.cuda.empty_cache()
except Exception as error: except Exception as error:
print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr) print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
@ -113,10 +119,10 @@ def setup_model(dirname):
if original_resolution != restored_img.shape[0:2]: if original_resolution != restored_img.shape[0:2]:
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR) restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
self.face_helper.clean_all()
if shared.opts.face_restoration_unload: if shared.opts.face_restoration_unload:
self.net = None self.send_model_to(devices.cpu)
self.face_helper = None
devices.torch_gc()
return restored_img return restored_img

View file

@ -1,9 +1,11 @@
import contextlib
import torch import torch
import gc import gc
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
from modules import errors from modules import errors
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
has_mps = getattr(torch, 'has_mps', False) has_mps = getattr(torch, 'has_mps', False)
cpu = torch.device("cpu") cpu = torch.device("cpu")
@ -33,8 +35,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32") errors.run(enable_tf32, "Enabling TF32")
device = get_optimal_device() device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
device_codeformer = cpu if has_mps else device
dtype = torch.float16 dtype = torch.float16
def randn(seed, shape): def randn(seed, shape):
@ -58,3 +59,11 @@ def randn_without_seed(shape):
return torch.randn(shape, device=device) return torch.randn(shape, device=device)
def autocast():
from modules import shared
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext()
return torch.autocast("cuda")

View file

@ -6,8 +6,7 @@ from PIL import Image
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
import modules.esrgam_model_arch as arch import modules.esrgam_model_arch as arch
from modules import shared, modelloader, images from modules import shared, modelloader, images, devices
from modules.devices import has_mps
from modules.paths import models_path from modules.paths import models_path
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts from modules.shared import opts
@ -97,7 +96,7 @@ class UpscalerESRGAN(Upscaler):
model = self.load_model(selected_model) model = self.load_model(selected_model)
if model is None: if model is None:
return img return img
model.to(shared.device) model.to(devices.device_esrgan)
img = esrgan_upscale(model, img) img = esrgan_upscale(model, img)
return img return img
@ -112,7 +111,7 @@ class UpscalerESRGAN(Upscaler):
print("Unable to load %s from %s" % (self.model_path, filename)) print("Unable to load %s from %s" % (self.model_path, filename))
return None return None
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None) pretrained_net = torch.load(filename, map_location='cpu' if shared.device.type == 'mps' else None)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
pretrained_net = fix_model_layers(crt_model, pretrained_net) pretrained_net = fix_model_layers(crt_model, pretrained_net)
@ -127,7 +126,7 @@ def upscale_without_tiling(model, img):
img = img[:, :, ::-1] img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255 img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(shared.device) img = img.unsqueeze(0).to(devices.device_esrgan)
with torch.no_grad(): with torch.no_grad():
output = model(img) output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy() output = output.squeeze().float().cpu().clamp_(0, 1).numpy()

View file

@ -21,7 +21,7 @@ def gfpgann():
global loaded_gfpgan_model global loaded_gfpgan_model
global model_path global model_path
if loaded_gfpgan_model is not None: if loaded_gfpgan_model is not None:
loaded_gfpgan_model.gfpgan.to(shared.device) loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
return loaded_gfpgan_model return loaded_gfpgan_model
if gfpgan_constructor is None: if gfpgan_constructor is None:
@ -37,25 +37,32 @@ def gfpgann():
print("Unable to load gfpgan model!") print("Unable to load gfpgan model!")
return None return None
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
model.gfpgan.to(shared.device)
loaded_gfpgan_model = model loaded_gfpgan_model = model
return model return model
def send_model_to(model, device):
model.gfpgan.to(device)
model.face_helper.face_det.to(device)
model.face_helper.face_parse.to(device)
def gfpgan_fix_faces(np_image): def gfpgan_fix_faces(np_image):
global loaded_gfpgan_model
model = gfpgann() model = gfpgann()
if model is None: if model is None:
return np_image return np_image
send_model_to(model, devices.device_gfpgan)
np_image_bgr = np_image[:, :, ::-1] np_image_bgr = np_image[:, :, ::-1]
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1] np_image = gfpgan_output_bgr[:, :, ::-1]
model.face_helper.clean_all()
if shared.opts.face_restoration_unload: if shared.opts.face_restoration_unload:
del model send_model_to(model, devices.cpu)
loaded_gfpgan_model = None
devices.torch_gc()
return np_image return np_image

View file

@ -287,6 +287,25 @@ def apply_filename_pattern(x, p, seed, prompt):
if seed is not None: if seed is not None:
x = x.replace("[seed]", str(seed)) x = x.replace("[seed]", str(seed))
if p is not None:
x = x.replace("[steps]", str(p.steps))
x = x.replace("[cfg]", str(p.cfg_scale))
x = x.replace("[width]", str(p.width))
x = x.replace("[height]", str(p.height))
#currently disabled if using the save button, will work otherwise
# if enabled it will cause a bug because styles is not included in the save_files data dictionary
if hasattr(p, "styles"):
x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False))
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
x = x.replace("[date]", datetime.date.today().isoformat())
x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
x = x.replace("[job_timestamp]", shared.state.job_timestamp)
# Apply [prompt] at last. Because it may contain any replacement word.^M
if prompt is not None: if prompt is not None:
x = x.replace("[prompt]", sanitize_filename_part(prompt)) x = x.replace("[prompt]", sanitize_filename_part(prompt))
if "[prompt_no_styles]" in x: if "[prompt_no_styles]" in x:
@ -295,7 +314,7 @@ def apply_filename_pattern(x, p, seed, prompt):
if len(style) > 0: if len(style) > 0:
style_parts = [y for y in style.split("{prompt}")] style_parts = [y for y in style.split("{prompt}")]
for part in style_parts: for part in style_parts:
prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',') prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip() prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False)) x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False))
@ -306,24 +325,6 @@ def apply_filename_pattern(x, p, seed, prompt):
words = ["empty"] words = ["empty"]
x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False)) x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
if p is not None:
x = x.replace("[steps]", str(p.steps))
x = x.replace("[cfg]", str(p.cfg_scale))
x = x.replace("[width]", str(p.width))
x = x.replace("[height]", str(p.height))
#currently disabled if using the save button, will work otherwise
# if enabled it will cause a bug because styles is not included in the save_files data dictionary
if hasattr(p, "styles"):
x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]), replace_spaces=False))
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
x = x.replace("[date]", datetime.date.today().isoformat())
x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
x = x.replace("[job_timestamp]", shared.state.job_timestamp)
if cmd_opts.hide_ui_dir_config: if cmd_opts.hide_ui_dir_config:
x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x) x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x)
@ -379,7 +380,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt) save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
if save_to_dirs: if save_to_dirs:
dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt) dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /')
path = os.path.join(path, dirname) path = os.path.join(path, dirname)
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)

View file

@ -23,8 +23,10 @@ def process_batch(p, input_dir, output_dir, args):
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.") print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
save_normally = output_dir == ''
p.do_not_save_grid = True p.do_not_save_grid = True
p.do_not_save_samples = True p.do_not_save_samples = not save_normally
state.job_count = len(images) * p.n_iter state.job_count = len(images) * p.n_iter
@ -48,7 +50,8 @@ def process_batch(p, input_dir, output_dir, args):
left, right = os.path.splitext(filename) left, right = os.path.splitext(filename)
filename = f"{left}-{n}{right}" filename = f"{left}-{n}{right}"
processed_image.save(os.path.join(output_dir, filename)) if not save_normally:
processed_image.save(os.path.join(output_dir, filename))
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
@ -126,4 +129,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
if opts.samples_log_stdout: if opts.samples_log_stdout:
print(generation_info_js) print(generation_info_js)
if opts.do_not_show_images:
processed.images = []
return processed.images, generation_info_js, plaintext_to_html(processed.info) return processed.images, generation_info_js, plaintext_to_html(processed.info)

View file

@ -1,4 +1,3 @@
import contextlib
import json import json
import math import math
import os import os
@ -85,7 +84,7 @@ class StableDiffusionProcessing:
self.s_tmin = opts.s_tmin self.s_tmin = opts.s_tmin
self.s_tmax = float('inf') # not representable as a standard ui option self.s_tmax = float('inf') # not representable as a standard ui option
self.s_noise = opts.s_noise self.s_noise = opts.s_noise
if not seed_enable_extras: if not seed_enable_extras:
self.subseed = -1 self.subseed = -1
self.subseed_strength = 0 self.subseed_strength = 0
@ -249,9 +248,16 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
return x return x
def get_fixed_seed(seed):
if seed is None or seed == '' or seed == -1:
return int(random.randrange(4294967294))
return seed
def fix_seed(p): def fix_seed(p):
p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == '' or p.seed == -1 else p.seed p.seed = get_fixed_seed(p.seed)
p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == '' or p.subseed == -1 else p.subseed p.subseed = get_fixed_seed(p.subseed)
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0): def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
@ -290,10 +296,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
assert(len(p.prompt) > 0) assert(len(p.prompt) > 0)
else: else:
assert p.prompt is not None assert p.prompt is not None
devices.torch_gc() devices.torch_gc()
fix_seed(p) seed = get_fixed_seed(p.seed)
subseed = get_fixed_seed(p.subseed)
if p.outpath_samples is not None: if p.outpath_samples is not None:
os.makedirs(p.outpath_samples, exist_ok=True) os.makedirs(p.outpath_samples, exist_ok=True)
@ -312,15 +319,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
else: else:
all_prompts = p.batch_size * p.n_iter * [p.prompt] all_prompts = p.batch_size * p.n_iter * [p.prompt]
if type(p.seed) == list: if type(seed) == list:
all_seeds = p.seed all_seeds = seed
else: else:
all_seeds = [int(p.seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))] all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
if type(p.subseed) == list: if type(subseed) == list:
all_subseeds = p.subseed all_subseeds = subseed
else: else:
all_subseeds = [int(p.subseed) + x for x in range(len(all_prompts))] all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]
def infotext(iteration=0, position_in_batch=0): def infotext(iteration=0, position_in_batch=0):
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch) return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
@ -330,10 +337,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
infotexts = [] infotexts = []
output_images = [] output_images = []
precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope) with torch.no_grad():
with torch.no_grad(), precision_scope("cuda"), ema_scope(): with devices.autocast():
p.init(all_prompts, all_seeds, all_subseeds) p.init(all_prompts, all_seeds, all_subseeds)
if state.job_count == -1: if state.job_count == -1:
state.job_count = p.n_iter state.job_count = p.n_iter
@ -352,8 +359,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
#uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
#c = p.sd_model.get_learned_conditioning(prompts) #c = p.sd_model.get_learned_conditioning(prompts)
uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps) with devices.autocast():
c = prompt_parser.get_learned_conditioning(prompts, p.steps) uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
c = prompt_parser.get_learned_conditioning(shared.sd_model, prompts, p.steps)
if len(model_hijack.comments) > 0: if len(model_hijack.comments) > 0:
for comment in model_hijack.comments: for comment in model_hijack.comments:
@ -362,13 +370,17 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.n_iter > 1: if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}" shared.state.job = f"Batch {n+1} out of {p.n_iter}"
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) with devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
if state.interrupted: if state.interrupted:
# if we are interruped, sample returns just noise # if we are interruped, sample returns just noise
# use the image collected previously in sampler loop # use the image collected previously in sampler loop
samples_ddim = shared.state.current_latent samples_ddim = shared.state.current_latent
samples_ddim = samples_ddim.to(devices.dtype)
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim) x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@ -394,6 +406,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration") images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
x_sample = modules.face_restoration.restore_faces(x_sample) x_sample = modules.face_restoration.restore_faces(x_sample)
devices.torch_gc()
devices.torch_gc() devices.torch_gc()
@ -530,7 +543,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
# GC now before running the next img2img to prevent running out of memory # GC now before running the next img2img to prevent running out of memory
x = None x = None
devices.torch_gc() devices.torch_gc()
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
return samples return samples

View file

@ -1,19 +1,7 @@
import re import re
from collections import namedtuple from collections import namedtuple
import torch
import modules.shared as shared import lark
re_prompt = re.compile(r'''
(.*?)
\[
([^]:]+):
(?:([^]:]*):)?
([0-9]*\.?[0-9]+)
]
|
(.+)
''', re.X)
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100): # will be represented with prompt_schedule like this (assuming steps=100):
@ -23,71 +11,96 @@ re_prompt = re.compile(r'''
# [75, 'fantasy landscape with a lake and an oak in background masterful'] # [75, 'fantasy landscape with a lake and an oak in background masterful']
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful'] # [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
schedule_parser = lark.Lark(r"""
!start: (prompt | /[][():]/+)*
prompt: (emphasized | scheduled | plain | WHITESPACE)*
!emphasized: "(" prompt ")"
| "(" prompt ":" prompt ")"
| "[" prompt "]"
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
WHITESPACE: /\s+/
plain: /([^\\\[\]():]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
""")
def get_learned_conditioning_prompt_schedules(prompts, steps): def get_learned_conditioning_prompt_schedules(prompts, steps):
res = [] """
cache = {} >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
>>> g("test")
[[10, 'test']]
>>> g("a [b:3]")
[[3, 'a '], [10, 'a b']]
>>> g("a [b: 3]")
[[3, 'a '], [10, 'a b']]
>>> g("a [[[b]]:2]")
[[2, 'a '], [10, 'a [[b]]']]
>>> g("[(a:2):3]")
[[3, ''], [10, '(a:2)']]
>>> g("a [b : c : 1] d")
[[1, 'a b d'], [10, 'a c d']]
>>> g("a[b:[c:d:2]:1]e")
[[1, 'abe'], [2, 'ace'], [10, 'ade']]
>>> g("a [unbalanced")
[[10, 'a [unbalanced']]
>>> g("a [b:.5] c")
[[5, 'a c'], [10, 'a b c']]
>>> g("a [{b|d{:.5] c") # not handling this right now
[[5, 'a c'], [10, 'a {b|d{ c']]
>>> g("((a][:b:c [d:3]")
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
"""
for prompt in prompts: def collect_steps(steps, tree):
prompt_schedule: list[list[str | int]] = [[steps, ""]] l = [steps]
class CollectSteps(lark.Visitor):
def scheduled(self, tree):
tree.children[-1] = float(tree.children[-1])
if tree.children[-1] < 1:
tree.children[-1] *= steps
tree.children[-1] = min(steps, int(tree.children[-1]))
l.append(tree.children[-1])
CollectSteps().visit(tree)
return sorted(set(l))
cached = cache.get(prompt, None) def at_step(step, tree):
if cached is not None: class AtStep(lark.Transformer):
res.append(cached) def scheduled(self, args):
continue before, after, _, when = args
yield before or () if step <= when else after
def start(self, args):
def flatten(x):
if type(x) == str:
yield x
else:
for gen in x:
yield from flatten(gen)
return ''.join(flatten(args))
def plain(self, args):
yield args[0].value
def __default__(self, data, children, meta):
for child in children:
yield from child
return AtStep().transform(tree)
for m in re_prompt.finditer(prompt): def get_schedule(prompt):
plaintext = m.group(1) if m.group(5) is None else m.group(5) try:
concept_from = m.group(2) tree = schedule_parser.parse(prompt)
concept_to = m.group(3) except lark.exceptions.LarkError as e:
if concept_to is None: if 0:
concept_to = concept_from import traceback
concept_from = "" traceback.print_exc()
swap_position = float(m.group(4)) if m.group(4) is not None else None return [[steps, prompt]]
return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
if swap_position is not None: promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
if swap_position < 1: return [promptdict[prompt] for prompt in prompts]
swap_position = swap_position * steps
swap_position = int(min(swap_position, steps))
swap_index = None
found_exact_index = False
for i in range(len(prompt_schedule)):
end_step = prompt_schedule[i][0]
prompt_schedule[i][1] += plaintext
if swap_position is not None and swap_index is None:
if swap_position == end_step:
swap_index = i
found_exact_index = True
if swap_position < end_step:
swap_index = i
if swap_index is not None:
if not found_exact_index:
prompt_schedule.insert(swap_index, [swap_position, prompt_schedule[swap_index][1]])
for i in range(len(prompt_schedule)):
end_step = prompt_schedule[i][0]
must_replace = swap_position < end_step
prompt_schedule[i][1] += concept_to if must_replace else concept_from
res.append(prompt_schedule)
cache[prompt] = prompt_schedule
#for t in prompt_schedule:
# print(t)
return res
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"]) ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"]) ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"])
def get_learned_conditioning(prompts, steps): def get_learned_conditioning(model, prompts, steps):
res = [] res = []
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps) prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
@ -101,7 +114,7 @@ def get_learned_conditioning(prompts, steps):
continue continue
texts = [x[1] for x in prompt_schedule] texts = [x[1] for x in prompt_schedule]
conds = shared.sd_model.get_learned_conditioning(texts) conds = model.get_learned_conditioning(texts)
cond_schedule = [] cond_schedule = []
for i, (end_at_step, text) in enumerate(prompt_schedule): for i, (end_at_step, text) in enumerate(prompt_schedule):
@ -114,12 +127,13 @@ def get_learned_conditioning(prompts, steps):
def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step): def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
res = torch.zeros(c.shape, device=shared.device, dtype=next(shared.sd_model.parameters()).dtype) param = c.schedules[0][0].cond
res = torch.zeros(c.shape, device=param.device, dtype=param.dtype)
for i, cond_schedule in enumerate(c.schedules): for i, cond_schedule in enumerate(c.schedules):
target_index = 0 target_index = 0
for curret_index, (end_at, cond) in enumerate(cond_schedule): for current, (end_at, cond) in enumerate(cond_schedule):
if current_step <= end_at: if current_step <= end_at:
target_index = curret_index target_index = current
break break
res[i] = cond_schedule[target_index].cond res[i] = cond_schedule[target_index].cond
@ -157,23 +171,26 @@ def parse_prompt_attention(text):
\\ - literal character '\' \\ - literal character '\'
anything else - just text anything else - just text
Example: >>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).' >>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
produces: >>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
[ >>> parse_prompt_attention('\(literal\]')
['a ', 1.0], [['(literal]', 1.0]]
['house', 1.5730000000000004], >>> parse_prompt_attention('(unnecessary)(parens)')
[' ', 1.1], [['unnecessaryparens', 1.1]]
['on', 1.0], >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[' a ', 1.1], [['a ', 1.0],
['hill', 0.55], ['house', 1.5730000000000004],
[', sun, ', 1.1], [' ', 1.1],
['sky', 1.4641000000000006], ['on', 1.0],
['.', 1.1] [' a ', 1.1],
] ['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]]
""" """
res = [] res = []
@ -215,4 +232,19 @@ def parse_prompt_attention(text):
if len(res) == 0: if len(res) == 0:
res = [["", 1.0]] res = [["", 1.0]]
# merge runs of identical weights
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res return res
if __name__ == "__main__":
import doctest
doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
else:
import torch # doctest faster

View file

@ -8,7 +8,7 @@ import torch
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
import modules.upscaler import modules.upscaler
from modules import shared, modelloader from modules import devices, modelloader
from modules.paths import models_path from modules.paths import models_path
from modules.scunet_model_arch import SCUNet as net from modules.scunet_model_arch import SCUNet as net
@ -51,12 +51,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
if model is None: if model is None:
return img return img
device = shared.device device = devices.device_scunet
img = np.array(img) img = np.array(img)
img = img[:, :, ::-1] img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255 img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(shared.device) img = img.unsqueeze(0).to(device)
img = img.to(device) img = img.to(device)
with torch.no_grad(): with torch.no_grad():
@ -69,7 +69,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
return PIL.Image.fromarray(output, 'RGB') return PIL.Image.fromarray(output, 'RGB')
def load_model(self, path: str): def load_model(self, path: str):
device = shared.device device = devices.device_scunet
if "http" in path: if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
progress=True) progress=True)

View file

@ -127,7 +127,7 @@ class VanillaStableDiffusionSampler:
return res return res
def initialize(self, p): def initialize(self, p):
self.eta = p.eta or opts.eta_ddim self.eta = p.eta if p.eta is not None else opts.eta_ddim
for fieldname in ['p_sample_ddim', 'p_sample_plms']: for fieldname in ['p_sample_ddim', 'p_sample_plms']:
if hasattr(self.sampler, fieldname): if hasattr(self.sampler, fieldname):

View file

@ -12,7 +12,7 @@ import modules.interrogate
import modules.memmon import modules.memmon
import modules.sd_models import modules.sd_models
import modules.styles import modules.styles
from modules.devices import get_optimal_device import modules.devices as devices
from modules.paths import script_path, sd_path from modules.paths import script_path, sd_path
sd_model_file = os.path.join(script_path, 'model.ckpt') sd_model_file = os.path.join(script_path, 'model.ckpt')
@ -46,6 +46,7 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[])
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
@ -54,6 +55,7 @@ parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide dire
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json')) parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="color-sketch")
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
@ -63,7 +65,11 @@ parser.add_argument("--enable-console-prompts", action='store_true', help="print
cmd_opts = parser.parse_args() cmd_opts = parser.parse_args()
device = get_optimal_device()
devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'])
device = devices.device
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
@ -183,7 +189,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}), "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}), "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
})) }))
options_templates.update(options_section(('face-restoration', "Face restoration"), { options_templates.update(options_section(('face-restoration', "Face restoration"), {
@ -195,7 +201,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('system', "System"), {
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}), "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."), "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
})) }))
options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {
@ -204,7 +210,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"enable_emphasis": OptionInfo(True, "Eemphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"filter_nsfw": OptionInfo(False, "Filter NSFW content"), "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
@ -224,6 +230,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"), "show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), "show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
"return_grid": OptionInfo(True, "Show grid in results for web"), "return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"font": OptionInfo("", "Font for image grids that have text"), "font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),

View file

@ -9,6 +9,9 @@ from torchvision import transforms
import random import random
import tqdm import tqdm
from modules import devices from modules import devices
import re
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
class PersonalizedBase(Dataset): class PersonalizedBase(Dataset):
@ -38,8 +41,8 @@ class PersonalizedBase(Dataset):
image = image.resize((self.width, self.height), PIL.Image.BICUBIC) image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
filename = os.path.basename(path) filename = os.path.basename(path)
filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-') filename_tokens = os.path.splitext(filename)[0]
filename_tokens = [token for token in filename_tokens if token.isalpha()] filename_tokens = re_tag.findall(filename_tokens)
npimage = np.array(image).astype(np.uint8) npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32) npimage = (npimage / 127.5 - 1.0).astype(np.float32)

View file

@ -26,7 +26,9 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca
if process_caption: if process_caption:
caption = "-" + shared.interrogator.generate_caption(image) caption = "-" + shared.interrogator.generate_caption(image)
else: else:
caption = "" caption = filename
caption = os.path.splitext(caption)[0]
caption = os.path.basename(caption)
image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png")) image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
subindex[0] += 1 subindex[0] += 1

View file

@ -164,7 +164,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name) log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
if save_embedding_every > 0: if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings") embedding_dir = os.path.join(log_directory, "embeddings")

View file

@ -48,5 +48,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
if opts.samples_log_stdout: if opts.samples_log_stdout:
print(generation_info_js) print(generation_info_js)
if opts.do_not_show_images:
processed.images = []
return processed.images, generation_info_js, plaintext_to_html(processed.info) return processed.images, generation_info_js, plaintext_to_html(processed.info)

View file

@ -69,7 +69,7 @@ random_symbol = '\U0001f3b2\ufe0f' # 🎲️
reuse_symbol = '\u267b\ufe0f' # ♻️ reuse_symbol = '\u267b\ufe0f' # ♻️
art_symbol = '\U0001f3a8' # 🎨 art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙ paste_symbol = '\u2199\ufe0f' # ↙
folder_symbol = '\uD83D\uDCC2' folder_symbol = '\U0001f4c2' # 📂
def plaintext_to_html(text): def plaintext_to_html(text):
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>" text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
@ -196,6 +196,11 @@ def wrap_gradio_call(func, extra_outputs=None):
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"] res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
elapsed = time.perf_counter() - t elapsed = time.perf_counter() - t
elapsed_m = int(elapsed // 60)
elapsed_s = elapsed % 60
elapsed_text = f"{elapsed_s:.2f}s"
if (elapsed_m > 0):
elapsed_text = f"{elapsed_m}m "+elapsed_text
if run_memmon: if run_memmon:
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()} mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
@ -210,7 +215,7 @@ def wrap_gradio_call(func, extra_outputs=None):
vram_html = '' vram_html = ''
# last item is always HTML # last item is always HTML
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>" res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
shared.state.interrupted = False shared.state.interrupted = False
shared.state.job_count = 0 shared.state.job_count = 0
@ -386,14 +391,22 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
outputs=[seed, dummy_component] outputs=[seed, dummy_component]
) )
def update_token_counter(text, steps): def update_token_counter(text, steps):
prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps) try:
prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps)
except Exception:
# a parsing error can happen here during typing, and we don't want to bother the user with
# messages related to it in console
prompt_schedules = [[[steps, text]]]
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
prompts = [prompt_text for step,prompt_text in flat_prompts] prompts = [prompt_text for step, prompt_text in flat_prompts]
tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1]) tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
style_class = ' class="red"' if (token_count > max_length) else "" style_class = ' class="red"' if (token_count > max_length) else ""
return f"<span {style_class}>{token_count}/{max_length}</span>" return f"<span {style_class}>{token_count}/{max_length}</span>"
def create_toprow(is_img2img): def create_toprow(is_img2img):
id_part = "img2img" if is_img2img else "txt2img" id_part = "img2img" if is_img2img else "txt2img"
@ -636,7 +649,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
with gr.TabItem('img2img', id='img2img'): with gr.TabItem('img2img', id='img2img'):
init_img = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil") init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool)
with gr.TabItem('Inpaint', id='inpaint'): with gr.TabItem('Inpaint', id='inpaint'):
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA") init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA")
@ -658,7 +671,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.TabItem('Batch img2img', id='batch'): with gr.TabItem('Batch img2img', id='batch'):
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.{hidden}</p>") gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs) img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs)
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs) img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs)

View file

@ -22,3 +22,4 @@ clean-fid
resize-right resize-right
torchdiffeq torchdiffeq
kornia kornia
lark

View file

@ -21,3 +21,4 @@ clean-fid==0.1.29
resize-right==0.0.2 resize-right==0.0.2
torchdiffeq==0.2.3 torchdiffeq==0.2.3
kornia==0.6.7 kornia==0.6.7
lark==1.1.2

View file

@ -1,5 +1,6 @@
from collections import namedtuple from collections import namedtuple
from copy import copy from copy import copy
from itertools import permutations
import random import random
from PIL import Image from PIL import Image
@ -29,6 +30,31 @@ def apply_prompt(p, x, xs):
p.negative_prompt = p.negative_prompt.replace(xs[0], x) p.negative_prompt = p.negative_prompt.replace(xs[0], x)
def apply_order(p, x, xs):
token_order = []
# Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen
for token in x:
token_order.append((p.prompt.find(token), token))
token_order.sort(key=lambda t: t[0])
prompt_parts = []
# Split the prompt up, taking out the tokens
for _, token in token_order:
n = p.prompt.find(token)
prompt_parts.append(p.prompt[0:n])
p.prompt = p.prompt[n + len(token):]
# Rebuild the prompt with the tokens in the order we want
prompt_tmp = ""
for idx, part in enumerate(prompt_parts):
prompt_tmp += part
prompt_tmp += x[idx]
p.prompt = prompt_tmp + p.prompt
samplers_dict = {} samplers_dict = {}
for i, sampler in enumerate(modules.sd_samplers.samplers): for i, sampler in enumerate(modules.sd_samplers.samplers):
samplers_dict[sampler.name.lower()] = i samplers_dict[sampler.name.lower()] = i
@ -60,16 +86,26 @@ def format_value_add_label(p, opt, x):
def format_value(p, opt, x): def format_value(p, opt, x):
if type(x) == float: if type(x) == float:
x = round(x, 8) x = round(x, 8)
return x return x
def format_value_join_list(p, opt, x):
return ", ".join(x)
def do_nothing(p, x, xs): def do_nothing(p, x, xs):
pass pass
def format_nothing(p, opt, x): def format_nothing(p, opt, x):
return "" return ""
def str_permutations(x):
"""dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
return x
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"]) AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"])
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"]) AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"])
@ -82,6 +118,7 @@ axis_options = [
AxisOption("Steps", int, apply_field("steps"), format_value_add_label), AxisOption("Steps", int, apply_field("steps"), format_value_add_label),
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label), AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
AxisOption("Prompt S/R", str, apply_prompt, format_value), AxisOption("Prompt S/R", str, apply_prompt, format_value),
AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list),
AxisOption("Sampler", str, apply_sampler, format_value), AxisOption("Sampler", str, apply_sampler, format_value),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value), AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label), AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label),
@ -131,6 +168,7 @@ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d
re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*") re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*")
class Script(scripts.Script): class Script(scripts.Script):
def title(self): def title(self):
return "X/Y plot" return "X/Y plot"
@ -206,6 +244,8 @@ class Script(scripts.Script):
valslist_ext.append(val) valslist_ext.append(val)
valslist = valslist_ext valslist = valslist_ext
elif opt.type == str_permutations:
valslist = list(permutations(valslist))
valslist = [opt.type(x) for x in valslist] valslist = [opt.type(x) for x in valslist]

View file

@ -403,3 +403,7 @@ input[type="range"]{
.red { .red {
color: red; color: red;
} }
#img2img_image div.h-60{
height: 480px;
}