Merge branch 'master' into 6866-fix-hires-prompt-matrix
This commit is contained in:
commit
cfc9849f3f
32 changed files with 395 additions and 209 deletions
|
@ -104,8 +104,7 @@ Alternatively, use online services (like Google Colab):
|
|||
1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
|
||||
2. Install [git](https://git-scm.com/download/win).
|
||||
3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
|
||||
4. Place stable diffusion checkpoint (`model.ckpt`) in the `models/Stable-diffusion` directory (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
|
||||
5. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
|
||||
4. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
|
||||
|
||||
### Automatic Installation on Linux
|
||||
1. Install the dependencies:
|
||||
|
@ -121,7 +120,7 @@ sudo pacman -S wget git python3
|
|||
```bash
|
||||
bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh)
|
||||
```
|
||||
|
||||
3. Run `webui.sh`.
|
||||
### Installation on Apple Silicon
|
||||
|
||||
Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).
|
||||
|
|
|
@ -8,8 +8,8 @@ titles = {
|
|||
"DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
|
||||
"DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
|
||||
|
||||
"Batch count": "How many batches of images to create",
|
||||
"Batch size": "How many image to create in a single batch",
|
||||
"Batch count": "How many batches of images to create (has no impact on generation performance or VRAM usage)",
|
||||
"Batch size": "How many image to create in a single batch (increases generation performance at cost of higher VRAM usage)",
|
||||
"CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results",
|
||||
"Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
|
||||
"\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
|
||||
|
@ -17,7 +17,7 @@ titles = {
|
|||
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
|
||||
"\u{1f4c2}": "Open images output directory",
|
||||
"\u{1f4be}": "Save style",
|
||||
"\U0001F5D1": "Clear prompt",
|
||||
"\u{1f5d1}": "Clear prompt",
|
||||
"\u{1f4cb}": "Apply selected styles to current prompt",
|
||||
"\u{1f4d2}": "Paste available values into the field",
|
||||
"\u{1f3b4}": "Show extra networks",
|
||||
|
@ -66,8 +66,8 @@ titles = {
|
|||
|
||||
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
|
||||
|
||||
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
||||
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
||||
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
||||
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg],[prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
||||
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
|
||||
|
||||
"Loopback": "Process an image, use it as an input, repeat.",
|
||||
|
|
|
@ -242,7 +242,7 @@ def prepare_environment():
|
|||
|
||||
sys.argv += shlex.split(commandline_args)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default='config.json')
|
||||
args, _ = parser.parse_known_args(sys.argv)
|
||||
|
||||
|
|
|
@ -1,21 +1,17 @@
|
|||
import sys, os, shlex
|
||||
import sys
|
||||
import contextlib
|
||||
import torch
|
||||
from modules import errors
|
||||
from packaging import version
|
||||
|
||||
if sys.platform == "darwin":
|
||||
from modules import mac_specific
|
||||
|
||||
|
||||
# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
|
||||
# check `getattr` and try it for compatibility
|
||||
def has_mps() -> bool:
|
||||
if not getattr(torch, 'has_mps', False):
|
||||
if sys.platform != "darwin":
|
||||
return False
|
||||
try:
|
||||
torch.zeros(1).to(torch.device("mps"))
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
else:
|
||||
return mac_specific.has_mps
|
||||
|
||||
def extract_device_id(args, name):
|
||||
for x in range(len(args)):
|
||||
|
@ -154,56 +150,3 @@ def test_for_nans(x, where):
|
|||
message += " Use --disable-nan-check commandline argument to disable this check."
|
||||
|
||||
raise NansException(message)
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
||||
orig_tensor_to = torch.Tensor.to
|
||||
def tensor_to_fix(self, *args, **kwargs):
|
||||
if self.device.type != 'mps' and \
|
||||
((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \
|
||||
(isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')):
|
||||
self = self.contiguous()
|
||||
return orig_tensor_to(self, *args, **kwargs)
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
||||
orig_layer_norm = torch.nn.functional.layer_norm
|
||||
def layer_norm_fix(*args, **kwargs):
|
||||
if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps':
|
||||
args = list(args)
|
||||
args[0] = args[0].contiguous()
|
||||
return orig_layer_norm(*args, **kwargs)
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
|
||||
orig_tensor_numpy = torch.Tensor.numpy
|
||||
def numpy_fix(self, *args, **kwargs):
|
||||
if self.requires_grad:
|
||||
self = self.detach()
|
||||
return orig_tensor_numpy(self, *args, **kwargs)
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
|
||||
orig_cumsum = torch.cumsum
|
||||
orig_Tensor_cumsum = torch.Tensor.cumsum
|
||||
def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
||||
if input.device.type == 'mps':
|
||||
output_dtype = kwargs.get('dtype', input.dtype)
|
||||
if output_dtype == torch.int64:
|
||||
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
||||
elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
|
||||
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
|
||||
return cumsum_func(input, *args, **kwargs)
|
||||
|
||||
|
||||
if has_mps():
|
||||
if version.parse(torch.__version__) < version.parse("1.13"):
|
||||
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
|
||||
torch.Tensor.to = tensor_to_fix
|
||||
torch.nn.functional.layer_norm = layer_norm_fix
|
||||
torch.Tensor.numpy = numpy_fix
|
||||
elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
||||
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
|
||||
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
|
||||
torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
|
||||
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# this file is adapted from https://github.com/victorca25/iNNfer
|
||||
|
||||
from collections import OrderedDict
|
||||
import math
|
||||
import functools
|
||||
import torch
|
||||
|
|
|
@ -2,6 +2,7 @@ import os
|
|||
import sys
|
||||
import traceback
|
||||
|
||||
import time
|
||||
import git
|
||||
|
||||
from modules import paths, shared
|
||||
|
@ -25,6 +26,7 @@ class Extension:
|
|||
self.status = ''
|
||||
self.can_update = False
|
||||
self.is_builtin = is_builtin
|
||||
self.version = ''
|
||||
|
||||
repo = None
|
||||
try:
|
||||
|
@ -40,6 +42,10 @@ class Extension:
|
|||
try:
|
||||
self.remote = next(repo.remote().urls, None)
|
||||
self.status = 'unknown'
|
||||
head = repo.head.commit
|
||||
ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
|
||||
self.version = f'{head.hexsha[:8]} ({ts})'
|
||||
|
||||
except Exception:
|
||||
self.remote = None
|
||||
|
||||
|
|
|
@ -74,8 +74,8 @@ def image_from_url_text(filedata):
|
|||
return image
|
||||
|
||||
|
||||
def add_paste_fields(tabname, init_img, fields):
|
||||
paste_fields[tabname] = {"init_img": init_img, "fields": fields}
|
||||
def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
|
||||
paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
|
||||
|
||||
# backwards compatibility for existing extensions
|
||||
import modules.ui
|
||||
|
@ -110,6 +110,7 @@ def connect_paste_params_buttons():
|
|||
for binding in registered_param_bindings:
|
||||
destination_image_component = paste_fields[binding.tabname]["init_img"]
|
||||
fields = paste_fields[binding.tabname]["fields"]
|
||||
override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"]
|
||||
|
||||
destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
|
||||
destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
|
||||
|
@ -130,7 +131,7 @@ def connect_paste_params_buttons():
|
|||
)
|
||||
|
||||
if binding.source_text_component is not None and fields is not None:
|
||||
connect_paste(binding.paste_button, fields, binding.source_text_component, binding.override_settings_component, binding.tabname)
|
||||
connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
|
||||
|
||||
if binding.source_tabname is not None and fields is not None:
|
||||
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
|
||||
|
|
|
@ -4,6 +4,7 @@ import os.path
|
|||
|
||||
import filelock
|
||||
|
||||
from modules import shared
|
||||
from modules.paths import data_path
|
||||
|
||||
|
||||
|
@ -68,6 +69,9 @@ def sha256(filename, title):
|
|||
if sha256_value is not None:
|
||||
return sha256_value
|
||||
|
||||
if shared.cmd_opts.no_hashing:
|
||||
return None
|
||||
|
||||
print(f"Calculating sha256 for {filename}: ", end='')
|
||||
sha256_value = calculate_sha256(filename)
|
||||
print(f"{sha256_value}")
|
||||
|
|
|
@ -307,7 +307,7 @@ class Hypernetwork:
|
|||
def shorthash(self):
|
||||
sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
|
||||
|
||||
return sha256[0:10]
|
||||
return sha256[0:10] if sha256 else None
|
||||
|
||||
|
||||
def list_hypernetworks(path):
|
||||
|
@ -380,8 +380,8 @@ def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
|
|||
layer.hyper_k = hypernetwork_layers[0]
|
||||
layer.hyper_v = hypernetwork_layers[1]
|
||||
|
||||
context_k = hypernetwork_layers[0](context_k)
|
||||
context_v = hypernetwork_layers[1](context_v)
|
||||
context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context_k)))
|
||||
context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context_v)))
|
||||
return context_k, context_v
|
||||
|
||||
|
||||
|
|
|
@ -16,8 +16,9 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
|||
from fonts.ttf import Roboto
|
||||
import string
|
||||
import json
|
||||
import hashlib
|
||||
|
||||
from modules import sd_samplers, shared, script_callbacks
|
||||
from modules import sd_samplers, shared, script_callbacks, errors
|
||||
from modules.shared import opts, cmd_opts
|
||||
|
||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||
|
@ -130,7 +131,7 @@ class GridAnnotation:
|
|||
self.size = None
|
||||
|
||||
|
||||
def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
|
||||
def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
||||
def wrap(drawing, text, font, line_length):
|
||||
lines = ['']
|
||||
for word in text.split():
|
||||
|
@ -194,32 +195,35 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
|
|||
line.allowed_width = allowed_width
|
||||
|
||||
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
|
||||
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
|
||||
ver_texts]
|
||||
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
|
||||
|
||||
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
|
||||
|
||||
result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
|
||||
result.paste(im, (pad_left, pad_top))
|
||||
result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
|
||||
|
||||
for row in range(rows):
|
||||
for col in range(cols):
|
||||
cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
|
||||
result.paste(cell, (pad_left + (width + margin) * col, pad_top + (height + margin) * row))
|
||||
|
||||
d = ImageDraw.Draw(result)
|
||||
|
||||
for col in range(cols):
|
||||
x = pad_left + width * col + width / 2
|
||||
x = pad_left + (width + margin) * col + width / 2
|
||||
y = pad_top / 2 - hor_text_heights[col] / 2
|
||||
|
||||
draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
|
||||
|
||||
for row in range(rows):
|
||||
x = pad_left / 2
|
||||
y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2
|
||||
y = pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2
|
||||
|
||||
draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def draw_prompt_matrix(im, width, height, all_prompts):
|
||||
def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
|
||||
prompts = all_prompts[1:]
|
||||
boundary = math.ceil(len(prompts) / 2)
|
||||
|
||||
|
@ -229,7 +233,7 @@ def draw_prompt_matrix(im, width, height, all_prompts):
|
|||
hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
|
||||
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
|
||||
|
||||
return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
|
||||
return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)
|
||||
|
||||
|
||||
def resize_image(resize_mode, im, width, height, upscaler_name=None):
|
||||
|
@ -340,6 +344,7 @@ class FilenameGenerator:
|
|||
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
|
||||
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
|
||||
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
|
||||
'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
|
||||
'prompt': lambda self: sanitize_filename_part(self.prompt),
|
||||
'prompt_no_styles': lambda self: self.prompt_no_style(),
|
||||
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
||||
|
@ -548,6 +553,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||
elif extension.lower() in (".jpg", ".jpeg", ".webp"):
|
||||
if image_to_save.mode == 'RGBA':
|
||||
image_to_save = image_to_save.convert("RGB")
|
||||
elif image_to_save.mode == 'I;16':
|
||||
image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
|
||||
|
||||
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
|
||||
|
||||
|
@ -570,17 +577,19 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||
|
||||
image.already_saved_as = fullfn
|
||||
|
||||
target_side_length = 4000
|
||||
oversize = image.width > target_side_length or image.height > target_side_length
|
||||
if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024):
|
||||
oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
|
||||
if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
|
||||
ratio = image.width / image.height
|
||||
|
||||
if oversize and ratio > 1:
|
||||
image = image.resize((target_side_length, image.height * target_side_length // image.width), LANCZOS)
|
||||
image = image.resize((opts.target_side_length, image.height * opts.target_side_length // image.width), LANCZOS)
|
||||
elif oversize:
|
||||
image = image.resize((image.width * target_side_length // image.height, target_side_length), LANCZOS)
|
||||
image = image.resize((image.width * opts.target_side_length // image.height, opts.target_side_length), LANCZOS)
|
||||
|
||||
_atomically_save_image(image, fullfn_without_extension, ".jpg")
|
||||
try:
|
||||
_atomically_save_image(image, fullfn_without_extension, ".jpg")
|
||||
except Exception as e:
|
||||
errors.display(e, "saving image as downscaled JPG")
|
||||
|
||||
if opts.save_txt and info is not None:
|
||||
txt_fullfn = f"{fullfn_without_extension}.txt"
|
||||
|
|
|
@ -73,10 +73,12 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
|
|||
|
||||
if not save_normally:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
if processed_image.mode == 'RGBA':
|
||||
processed_image = processed_image.convert("RGB")
|
||||
processed_image.save(os.path.join(output_dir, filename))
|
||||
|
||||
|
||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
|
||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
|
||||
override_settings = create_override_settings_dict(override_settings_texts)
|
||||
|
||||
is_batch = mode == 5
|
||||
|
@ -142,6 +144,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||
inpainting_fill=inpainting_fill,
|
||||
resize_mode=resize_mode,
|
||||
denoising_strength=denoising_strength,
|
||||
image_cfg_scale=image_cfg_scale,
|
||||
inpaint_full_res=inpaint_full_res,
|
||||
inpaint_full_res_padding=inpaint_full_res_padding,
|
||||
inpainting_mask_invert=inpainting_mask_invert,
|
||||
|
|
53
modules/mac_specific.py
Normal file
53
modules/mac_specific.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
import torch
|
||||
from modules import paths
|
||||
from modules.sd_hijack_utils import CondFunc
|
||||
from packaging import version
|
||||
|
||||
|
||||
# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
|
||||
# check `getattr` and try it for compatibility
|
||||
def check_for_mps() -> bool:
|
||||
if not getattr(torch, 'has_mps', False):
|
||||
return False
|
||||
try:
|
||||
torch.zeros(1).to(torch.device("mps"))
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
has_mps = check_for_mps()
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
|
||||
def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
||||
if input.device.type == 'mps':
|
||||
output_dtype = kwargs.get('dtype', input.dtype)
|
||||
if output_dtype == torch.int64:
|
||||
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
||||
elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
|
||||
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
|
||||
return cumsum_func(input, *args, **kwargs)
|
||||
|
||||
|
||||
if has_mps:
|
||||
# MPS fix for randn in torchsde
|
||||
CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
|
||||
|
||||
if version.parse(torch.__version__) < version.parse("1.13"):
|
||||
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
||||
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
|
||||
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
||||
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
|
||||
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
|
||||
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
|
||||
elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
||||
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
|
||||
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
|
||||
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
|
||||
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
||||
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
||||
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
|
||||
|
|
@ -45,6 +45,9 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
|||
full_path = file
|
||||
if os.path.isdir(full_path):
|
||||
continue
|
||||
if os.path.islink(full_path) and not os.path.exists(full_path):
|
||||
print(f"Skipping broken symlink: {full_path}")
|
||||
continue
|
||||
if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
|
||||
continue
|
||||
if len(ext_filter) != 0:
|
||||
|
|
|
@ -186,7 +186,7 @@ class StableDiffusionProcessing:
|
|||
return conditioning
|
||||
|
||||
def edit_image_conditioning(self, source_image):
|
||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
|
||||
conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
|
||||
|
||||
return conditioning_image
|
||||
|
||||
|
@ -268,6 +268,7 @@ class Processed:
|
|||
self.height = p.height
|
||||
self.sampler_name = p.sampler_name
|
||||
self.cfg_scale = p.cfg_scale
|
||||
self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
||||
self.steps = p.steps
|
||||
self.batch_size = p.batch_size
|
||||
self.restore_faces = p.restore_faces
|
||||
|
@ -445,6 +446,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||
"Steps": p.steps,
|
||||
"Sampler": p.sampler_name,
|
||||
"CFG scale": p.cfg_scale,
|
||||
"Image CFG scale": getattr(p, 'image_cfg_scale', None),
|
||||
"Seed": all_seeds[index],
|
||||
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
||||
"Size": f"{p.width}x{p.height}",
|
||||
|
@ -541,8 +543,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
||||
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||
|
||||
_, extra_network_data = extra_networks.parse_prompts(p.all_prompts[0:1])
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.process(p)
|
||||
|
||||
|
@ -580,13 +580,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
|
||||
sd_vae_approx.model()
|
||||
|
||||
if not p.disable_extra_networks:
|
||||
extra_networks.activate(p, extra_network_data)
|
||||
|
||||
with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
file.write(processed.infotext(p, 0))
|
||||
|
||||
if state.job_count == -1:
|
||||
state.job_count = p.n_iter
|
||||
|
||||
|
@ -607,11 +600,24 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
if len(prompts) == 0:
|
||||
break
|
||||
|
||||
prompts, _ = extra_networks.parse_prompts(prompts)
|
||||
prompts, extra_network_data = extra_networks.parse_prompts(prompts)
|
||||
|
||||
if not p.disable_extra_networks:
|
||||
with devices.autocast():
|
||||
extra_networks.activate(p, extra_network_data)
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
|
||||
|
||||
# params.txt should be saved after scripts.process_batch, since the
|
||||
# infotext could be modified by that callback
|
||||
# Example: a wildcard processed by process_batch sets an extra model
|
||||
# strength, which is saved as "Model Strength: 1.0" in the infotext
|
||||
if n == 0:
|
||||
with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
file.write(processed.infotext(p, 0))
|
||||
|
||||
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
|
||||
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)
|
||||
|
||||
|
@ -901,12 +907,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
sampler = None
|
||||
|
||||
def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
|
||||
def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.init_images = init_images
|
||||
self.resize_mode: int = resize_mode
|
||||
self.denoising_strength: float = denoising_strength
|
||||
self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
|
||||
self.init_latent = None
|
||||
self.image_mask = mask
|
||||
self.latent_mask = None
|
||||
|
|
|
@ -46,6 +46,18 @@ class CFGDenoiserParams:
|
|||
"""Total number of sampling steps planned"""
|
||||
|
||||
|
||||
class CFGDenoisedParams:
|
||||
def __init__(self, x, sampling_step, total_sampling_steps):
|
||||
self.x = x
|
||||
"""Latent image representation in the process of being denoised"""
|
||||
|
||||
self.sampling_step = sampling_step
|
||||
"""Current Sampling step number"""
|
||||
|
||||
self.total_sampling_steps = total_sampling_steps
|
||||
"""Total number of sampling steps planned"""
|
||||
|
||||
|
||||
class UiTrainTabParams:
|
||||
def __init__(self, txt2img_preview_params):
|
||||
self.txt2img_preview_params = txt2img_preview_params
|
||||
|
@ -68,6 +80,7 @@ callback_map = dict(
|
|||
callbacks_before_image_saved=[],
|
||||
callbacks_image_saved=[],
|
||||
callbacks_cfg_denoiser=[],
|
||||
callbacks_cfg_denoised=[],
|
||||
callbacks_before_component=[],
|
||||
callbacks_after_component=[],
|
||||
callbacks_image_grid=[],
|
||||
|
@ -150,6 +163,14 @@ def cfg_denoiser_callback(params: CFGDenoiserParams):
|
|||
report_exception(c, 'cfg_denoiser_callback')
|
||||
|
||||
|
||||
def cfg_denoised_callback(params: CFGDenoisedParams):
|
||||
for c in callback_map['callbacks_cfg_denoised']:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'cfg_denoised_callback')
|
||||
|
||||
|
||||
def before_component_callback(component, **kwargs):
|
||||
for c in callback_map['callbacks_before_component']:
|
||||
try:
|
||||
|
@ -283,6 +304,14 @@ def on_cfg_denoiser(callback):
|
|||
add_callback(callback_map['callbacks_cfg_denoiser'], callback)
|
||||
|
||||
|
||||
def on_cfg_denoised(callback):
|
||||
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
|
||||
The callback is called with one argument:
|
||||
- params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_cfg_denoised'], callback)
|
||||
|
||||
|
||||
def on_before_component(callback):
|
||||
"""register a function to be called before a component is created.
|
||||
The callback is called with arguments:
|
||||
|
|
|
@ -20,8 +20,9 @@ class DisableInitialization:
|
|||
```
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, disable_clip=True):
|
||||
self.replaced = []
|
||||
self.disable_clip = disable_clip
|
||||
|
||||
def replace(self, obj, field, func):
|
||||
original = getattr(obj, field, None)
|
||||
|
@ -75,12 +76,14 @@ class DisableInitialization:
|
|||
self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
|
||||
self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
|
||||
self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
|
||||
self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
|
||||
self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
|
||||
self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
|
||||
self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
|
||||
self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
|
||||
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
||||
|
||||
if self.disable_clip:
|
||||
self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
|
||||
self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
|
||||
self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
|
||||
self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
|
||||
self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
|
||||
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
for obj, field, original in self.replaced:
|
||||
|
|
|
@ -104,6 +104,9 @@ class StableDiffusionModelHijack:
|
|||
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
|
||||
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
|
||||
if m.cond_stage_key == "edit":
|
||||
sd_hijack_unet.hijack_ddpm_edit()
|
||||
|
||||
self.optimization_method = apply_optimizations()
|
||||
|
||||
self.clip = m.cond_stage_model
|
||||
|
|
|
@ -11,6 +11,7 @@ import ldm.models.diffusion.plms
|
|||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
|
||||
from ldm.models.diffusion.sampling_util import norm_thresholding
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
|
@ -44,6 +44,7 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
|||
with devices.autocast():
|
||||
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
|
||||
|
||||
|
||||
class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
torch.nn.GELU.__init__(self, *args, **kwargs)
|
||||
|
@ -53,6 +54,16 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
|||
else:
|
||||
return torch.nn.GELU.forward(self, x)
|
||||
|
||||
|
||||
ddpm_edit_hijack = None
|
||||
def hijack_ddpm_edit():
|
||||
global ddpm_edit_hijack
|
||||
if not ddpm_edit_hijack:
|
||||
CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
||||
CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
||||
ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
||||
|
||||
|
||||
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
||||
|
|
|
@ -59,13 +59,17 @@ class CheckpointInfo:
|
|||
|
||||
def calculate_shorthash(self):
|
||||
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
|
||||
if self.sha256 is None:
|
||||
return
|
||||
|
||||
self.shorthash = self.sha256[0:10]
|
||||
|
||||
if self.shorthash not in self.ids:
|
||||
self.ids += [self.shorthash, self.sha256]
|
||||
self.register()
|
||||
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
|
||||
|
||||
checkpoints_list.pop(self.title)
|
||||
self.title = f'{self.name} [{self.shorthash}]'
|
||||
self.register()
|
||||
|
||||
return self.shorthash
|
||||
|
||||
|
@ -101,7 +105,7 @@ def checkpoint_tiles():
|
|||
def list_models():
|
||||
checkpoints_list.clear()
|
||||
checkpoint_alisases.clear()
|
||||
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"])
|
||||
model_list = modelloader.load_models(model_path=model_path, model_url="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors", command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
|
||||
|
||||
cmd_ckpt = shared.cmd_opts.ckpt
|
||||
if os.path.exists(cmd_ckpt):
|
||||
|
@ -158,7 +162,7 @@ def select_checkpoint():
|
|||
print(f" - directory {model_path}", file=sys.stderr)
|
||||
if shared.cmd_opts.ckpt_dir is not None:
|
||||
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
|
||||
print("Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
|
||||
print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr)
|
||||
exit(1)
|
||||
|
||||
checkpoint_info = next(iter(checkpoints_list.values()))
|
||||
|
@ -350,6 +354,9 @@ def repair_config(sd_config):
|
|||
sd_config.model.params.unet_config.params.use_fp16 = True
|
||||
|
||||
|
||||
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
||||
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
||||
|
||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
|
||||
from modules import lowvram, sd_hijack
|
||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||
|
@ -370,6 +377,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
|
|||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||
clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
|
||||
|
||||
timer.record("find config")
|
||||
|
||||
|
@ -382,7 +390,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
|
|||
|
||||
sd_model = None
|
||||
try:
|
||||
with sd_disable_initialization.DisableInitialization():
|
||||
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
|
|
@ -2,7 +2,6 @@ from collections import namedtuple
|
|||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
import torchsde._brownian.brownian_interval
|
||||
from modules import devices, processing, images, sd_vae_approx
|
||||
|
||||
from modules.shared import opts, state
|
||||
|
@ -61,18 +60,3 @@ def store_latent(decoded):
|
|||
|
||||
class InterruptedException(BaseException):
|
||||
pass
|
||||
|
||||
|
||||
# MPS fix for randn in torchsde
|
||||
# XXX move this to separate file for MPS
|
||||
def torchsde_randn(size, dtype, device, seed):
|
||||
if device.type == 'mps':
|
||||
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
|
||||
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
|
||||
else:
|
||||
generator = torch.Generator(device).manual_seed(int(seed))
|
||||
return torch.randn(size, dtype=dtype, device=device, generator=generator)
|
||||
|
||||
|
||||
torchsde._brownian.brownian_interval._randn = torchsde_randn
|
||||
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
from collections import deque
|
||||
import torch
|
||||
import inspect
|
||||
import einops
|
||||
import k_diffusion.sampling
|
||||
from modules import prompt_parser, devices, sd_samplers_common
|
||||
|
||||
from modules.shared import opts, state
|
||||
import modules.shared as shared
|
||||
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
||||
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
|
||||
|
||||
samplers_k_diffusion = [
|
||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
|
||||
|
@ -56,6 +58,7 @@ class CFGDenoiser(torch.nn.Module):
|
|||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.step = 0
|
||||
self.image_cfg_scale = None
|
||||
|
||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||
|
@ -67,19 +70,36 @@ class CFGDenoiser(torch.nn.Module):
|
|||
|
||||
return denoised
|
||||
|
||||
def combine_denoised_for_edit_model(self, x_out, cond_scale):
|
||||
out_cond, out_img_cond, out_uncond = x_out.chunk(3)
|
||||
denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
|
||||
|
||||
return denoised
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
|
||||
# so is_edit_model is set to False to support AND composition.
|
||||
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||
|
||||
assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
||||
|
||||
batch_size = len(conds_list)
|
||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||
if not is_edit_model:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
|
||||
else:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
|
||||
|
||||
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
|
||||
cfg_denoiser_callback(denoiser_params)
|
||||
|
@ -88,7 +108,10 @@ class CFGDenoiser(torch.nn.Module):
|
|||
sigma_in = denoiser_params.sigma
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1]:
|
||||
cond_in = torch.cat([tensor, uncond])
|
||||
if not is_edit_model:
|
||||
cond_in = torch.cat([tensor, uncond])
|
||||
else:
|
||||
cond_in = torch.cat([tensor, uncond, uncond])
|
||||
|
||||
if shared.batch_cond_uncond:
|
||||
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
|
||||
|
@ -104,10 +127,19 @@ class CFGDenoiser(torch.nn.Module):
|
|||
for batch_offset in range(0, tensor.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = min(a + batch_size, tensor.shape[0])
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
|
||||
|
||||
if not is_edit_model:
|
||||
c_crossattn = [tensor[a:b]]
|
||||
else:
|
||||
c_crossattn = torch.cat([tensor[a:b]], uncond)
|
||||
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]})
|
||||
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
|
||||
|
||||
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
|
||||
cfg_denoised_callback(denoised_params)
|
||||
|
||||
devices.test_for_nans(x_out, "unet")
|
||||
|
||||
if opts.live_preview_content == "Prompt":
|
||||
|
@ -115,7 +147,10 @@ class CFGDenoiser(torch.nn.Module):
|
|||
elif opts.live_preview_content == "Negative prompt":
|
||||
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
|
||||
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
if not is_edit_model:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
else:
|
||||
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
||||
|
||||
if self.mask is not None:
|
||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||
|
@ -198,6 +233,7 @@ class KDiffusionSampler:
|
|||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
self.model_wrap_cfg.step = 0
|
||||
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
||||
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
|
||||
|
||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
||||
|
@ -237,6 +273,16 @@ class KDiffusionSampler:
|
|||
|
||||
return sigmas
|
||||
|
||||
def create_noise_sampler(self, x, sigmas, p):
|
||||
"""For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
|
||||
if shared.opts.no_dpmpp_sde_batch_determinism:
|
||||
return None
|
||||
|
||||
from k_diffusion.sampling import BrownianTreeNoiseSampler
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
|
||||
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||
|
||||
|
@ -246,31 +292,38 @@ class KDiffusionSampler:
|
|||
xi = x + noise * sigma_sched[0]
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
if 'sigma_min' in inspect.signature(self.func).parameters:
|
||||
parameters = inspect.signature(self.func).parameters
|
||||
|
||||
if 'sigma_min' in parameters:
|
||||
## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
|
||||
extra_params_kwargs['sigma_min'] = sigma_sched[-2]
|
||||
if 'sigma_max' in inspect.signature(self.func).parameters:
|
||||
if 'sigma_max' in parameters:
|
||||
extra_params_kwargs['sigma_max'] = sigma_sched[0]
|
||||
if 'n' in inspect.signature(self.func).parameters:
|
||||
if 'n' in parameters:
|
||||
extra_params_kwargs['n'] = len(sigma_sched) - 1
|
||||
if 'sigma_sched' in inspect.signature(self.func).parameters:
|
||||
if 'sigma_sched' in parameters:
|
||||
extra_params_kwargs['sigma_sched'] = sigma_sched
|
||||
if 'sigmas' in inspect.signature(self.func).parameters:
|
||||
if 'sigmas' in parameters:
|
||||
extra_params_kwargs['sigmas'] = sigma_sched
|
||||
|
||||
if self.funcname == 'sample_dpmpp_sde':
|
||||
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
||||
extra_params_kwargs['noise_sampler'] = noise_sampler
|
||||
|
||||
self.model_wrap_cfg.init_latent = x
|
||||
self.last_latent = x
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
|
||||
extra_args={
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale
|
||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
'cond_scale': p.cfg_scale,
|
||||
}
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps = steps or p.steps
|
||||
|
||||
sigmas = self.get_sigmas(p, steps)
|
||||
|
@ -278,14 +331,20 @@ class KDiffusionSampler:
|
|||
x = x * sigmas[0]
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
if 'sigma_min' in inspect.signature(self.func).parameters:
|
||||
parameters = inspect.signature(self.func).parameters
|
||||
|
||||
if 'sigma_min' in parameters:
|
||||
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
|
||||
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
|
||||
if 'n' in inspect.signature(self.func).parameters:
|
||||
if 'n' in parameters:
|
||||
extra_params_kwargs['n'] = steps
|
||||
else:
|
||||
extra_params_kwargs['sigmas'] = sigmas
|
||||
|
||||
if self.funcname == 'sample_dpmpp_sde':
|
||||
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
||||
extra_params_kwargs['noise_sampler'] = noise_sampler
|
||||
|
||||
self.last_latent = x
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
|
||||
'cond': conditioning,
|
||||
|
|
|
@ -106,7 +106,7 @@ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, req
|
|||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||
parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
|
||||
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
||||
|
||||
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
||||
|
||||
|
||||
script_loading.preload_extensions(extensions.extensions_dir, parser)
|
||||
|
@ -325,9 +325,11 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
|||
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
|
||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
||||
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
|
||||
"export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"),
|
||||
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
|
||||
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
|
||||
|
||||
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
|
||||
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
|
||||
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
||||
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
|
||||
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
|
||||
|
@ -364,7 +366,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
|
|||
}))
|
||||
|
||||
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
||||
"face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
|
||||
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
|
||||
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
||||
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
|
||||
}))
|
||||
|
@ -414,6 +416,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
|
||||
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
|
||||
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
|
||||
}))
|
||||
|
||||
|
|
|
@ -20,4 +20,4 @@ def sd_vae_items():
|
|||
def refresh_vae_list():
|
||||
import modules.sd_vae
|
||||
|
||||
return modules.sd_vae.refresh_vae_list
|
||||
modules.sd_vae.refresh_vae_list()
|
||||
|
|
|
@ -479,8 +479,8 @@ def create_ui():
|
|||
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
|
||||
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
|
||||
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
|
||||
if opts.dimensions_and_batch_together:
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
|
||||
with gr.Column(elem_id="txt2img_column_batch"):
|
||||
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
|
||||
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
|
||||
|
@ -631,9 +631,9 @@ def create_ui():
|
|||
(hr_resize_y, "Hires resize-2"),
|
||||
*modules.scripts.scripts_txt2img.infotext_fields
|
||||
]
|
||||
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
|
||||
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
|
||||
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
|
||||
paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None, override_settings_component=override_settings,
|
||||
paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None,
|
||||
))
|
||||
|
||||
txt2img_preview_params = [
|
||||
|
@ -757,15 +757,17 @@ def create_ui():
|
|||
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
|
||||
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
|
||||
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
||||
if opts.dimensions_and_batch_together:
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
||||
with gr.Column(elem_id="img2img_column_batch"):
|
||||
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
|
||||
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
|
||||
|
||||
elif category == "cfg":
|
||||
with FormGroup():
|
||||
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
|
||||
with FormRow():
|
||||
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
|
||||
image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
|
||||
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
|
||||
|
||||
elif category == "seed":
|
||||
|
@ -861,6 +863,7 @@ def create_ui():
|
|||
batch_count,
|
||||
batch_size,
|
||||
cfg_scale,
|
||||
image_cfg_scale,
|
||||
denoising_strength,
|
||||
seed,
|
||||
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
|
||||
|
@ -947,6 +950,7 @@ def create_ui():
|
|||
(sampler_index, "Sampler"),
|
||||
(restore_faces, "Face restoration"),
|
||||
(cfg_scale, "CFG scale"),
|
||||
(image_cfg_scale, "Image CFG scale"),
|
||||
(seed, "Seed"),
|
||||
(width, "Size-1"),
|
||||
(height, "Size-2"),
|
||||
|
@ -959,10 +963,10 @@ def create_ui():
|
|||
(mask_blur, "Mask blur"),
|
||||
*modules.scripts.scripts_img2img.infotext_fields
|
||||
]
|
||||
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
|
||||
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
|
||||
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
|
||||
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
|
||||
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
|
||||
paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None, override_settings_component=override_settings,
|
||||
paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None,
|
||||
))
|
||||
|
||||
modules.scripts.scripts_current = None
|
||||
|
@ -1591,6 +1595,12 @@ def create_ui():
|
|||
outputs=[component, text_settings],
|
||||
)
|
||||
|
||||
text_settings.change(
|
||||
fn=lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit"),
|
||||
inputs=[],
|
||||
outputs=[image_cfg_scale],
|
||||
)
|
||||
|
||||
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
|
||||
button_set_checkpoint.click(
|
||||
fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
|
||||
|
@ -1772,7 +1782,7 @@ def versions_html():
|
|||
return f"""
|
||||
python: <span title="{sys.version}">{python_version}</span>
|
||||
•
|
||||
torch: {torch.__version__}
|
||||
torch: {getattr(torch, '__long_version__',torch.__version__)}
|
||||
•
|
||||
xformers: {xformers_version}
|
||||
•
|
||||
|
|
|
@ -80,6 +80,7 @@ def extension_table():
|
|||
<tr>
|
||||
<th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
|
||||
<th>URL</th>
|
||||
<th><abbr title="Extension version">Version</abbr></th>
|
||||
<th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click apply button">Update</abbr></th>
|
||||
</tr>
|
||||
</thead>
|
||||
|
@ -87,11 +88,7 @@ def extension_table():
|
|||
"""
|
||||
|
||||
for ext in extensions.extensions:
|
||||
remote = ""
|
||||
if ext.is_builtin:
|
||||
remote = "built-in"
|
||||
elif ext.remote:
|
||||
remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
|
||||
remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
|
||||
|
||||
if ext.can_update:
|
||||
ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
|
||||
|
@ -102,6 +99,7 @@ def extension_table():
|
|||
<tr>
|
||||
<td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
|
||||
<td>{remote}</td>
|
||||
<td>{ext.version}</td>
|
||||
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
|
||||
</tr>
|
||||
"""
|
||||
|
|
|
@ -26,11 +26,12 @@ def add_pages_to_demo(app):
|
|||
def fetch_file(filename: str = ""):
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
if not any([Path(x).resolve() in Path(filename).resolve().parents for x in allowed_dirs]):
|
||||
if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
|
||||
|
||||
if os.path.splitext(filename)[1].lower() != ".png":
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Only png.")
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
if ext not in (".png", ".jpg"):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg.")
|
||||
|
||||
# would profit from returning 304
|
||||
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
||||
|
@ -75,6 +76,10 @@ class ExtraNetworksPage:
|
|||
while subdir.startswith("/"):
|
||||
subdir = subdir[1:]
|
||||
|
||||
is_empty = len(os.listdir(x)) == 0
|
||||
if not is_empty and not subdir.endswith("/"):
|
||||
subdir = subdir + "/"
|
||||
|
||||
subdirs[subdir] = 1
|
||||
|
||||
if subdirs:
|
||||
|
@ -93,11 +98,13 @@ class ExtraNetworksPage:
|
|||
dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
|
||||
items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
|
||||
|
||||
self_name_id = self.name.replace(" ", "_")
|
||||
|
||||
res = f"""
|
||||
<div id='{tabname}_{self.name}_subdirs' class='extra-network-subdirs extra-network-subdirs-{view}'>
|
||||
<div id='{tabname}_{self_name_id}_subdirs' class='extra-network-subdirs extra-network-subdirs-{view}'>
|
||||
{subdirs_html}
|
||||
</div>
|
||||
<div id='{tabname}_{self.name}_cards' class='extra-network-{view}'>
|
||||
<div id='{tabname}_{self_name_id}_cards' class='extra-network-{view}'>
|
||||
{items_html}
|
||||
</div>
|
||||
"""
|
||||
|
|
|
@ -27,3 +27,4 @@ GitPython==3.1.27
|
|||
torchsde==0.2.5
|
||||
safetensors==0.2.7
|
||||
httpcore<=0.15
|
||||
fastapi==0.90.1
|
||||
|
|
|
@ -6,7 +6,7 @@ from tqdm import trange
|
|||
import modules.scripts as scripts
|
||||
import gradio as gr
|
||||
|
||||
from modules import processing, shared, sd_samplers, prompt_parser
|
||||
from modules import processing, shared, sd_samplers, prompt_parser, sd_samplers_common
|
||||
from modules.processing import Processed
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
|
||||
|
@ -50,7 +50,7 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
|
|||
|
||||
x = x + d * dt
|
||||
|
||||
sd_samplers.store_latent(x)
|
||||
sd_samplers_common.store_latent(x)
|
||||
|
||||
# This shouldn't be necessary, but solved some VRAM issues
|
||||
del x_in, sigma_in, cond_in, c_out, c_in, t,
|
||||
|
@ -104,7 +104,7 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
|
|||
dt = sigmas[i] - sigmas[i - 1]
|
||||
x = x + d * dt
|
||||
|
||||
sd_samplers.store_latent(x)
|
||||
sd_samplers_common.store_latent(x)
|
||||
|
||||
# This shouldn't be necessary, but solved some VRAM issues
|
||||
del x_in, sigma_in, cond_in, c_out, c_in, t,
|
||||
|
|
|
@ -44,16 +44,34 @@ class Script(scripts.Script):
|
|||
def title(self):
|
||||
return "Prompt matrix"
|
||||
|
||||
def ui(self, is_img2img):
|
||||
put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start"))
|
||||
different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds"))
|
||||
def ui(self, is_img2img):
|
||||
gr.HTML('<br />')
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start"))
|
||||
different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds"))
|
||||
with gr.Column():
|
||||
prompt_type = gr.Radio(["positive", "negative"], label="Select prompt", elem_id=self.elem_id("prompt_type"), value="positive")
|
||||
variations_delimiter = gr.Radio(["comma", "space"], label="Select joining char", elem_id=self.elem_id("variations_delimiter"), value="comma")
|
||||
with gr.Column():
|
||||
margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
|
||||
|
||||
return [put_at_start, different_seeds]
|
||||
return [put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size]
|
||||
|
||||
def run(self, p, put_at_start, different_seeds):
|
||||
def run(self, p, put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size):
|
||||
modules.processing.fix_seed(p)
|
||||
# Raise error if promp type is not positive or negative
|
||||
if prompt_type not in ["positive", "negative"]:
|
||||
raise ValueError(f"Unknown prompt type {prompt_type}")
|
||||
# Raise error if variations delimiter is not comma or space
|
||||
if variations_delimiter not in ["comma", "space"]:
|
||||
raise ValueError(f"Unknown variations delimiter {variations_delimiter}")
|
||||
|
||||
original_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt
|
||||
prompt = p.prompt if prompt_type == "positive" else p.negative_prompt
|
||||
original_prompt = prompt[0] if type(prompt) == list else prompt
|
||||
positive_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt
|
||||
|
||||
delimiter = ", " if variations_delimiter == "comma" else " "
|
||||
|
||||
all_prompts = []
|
||||
prompt_matrix_parts = original_prompt.split("|")
|
||||
|
@ -66,20 +84,23 @@ class Script(scripts.Script):
|
|||
else:
|
||||
selected_prompts = [prompt_matrix_parts[0]] + selected_prompts
|
||||
|
||||
all_prompts.append(", ".join(selected_prompts))
|
||||
all_prompts.append(delimiter.join(selected_prompts))
|
||||
|
||||
p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
|
||||
p.do_not_save_grid = True
|
||||
|
||||
print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
|
||||
|
||||
p.prompt = all_prompts
|
||||
if prompt_type == "positive":
|
||||
p.prompt = all_prompts
|
||||
else:
|
||||
p.negative_prompt = all_prompts
|
||||
p.seed = [p.seed + (i if different_seeds else 0) for i in range(len(all_prompts))]
|
||||
p.prompt_for_display = original_prompt
|
||||
p.prompt_for_display = positive_prompt
|
||||
processed = process_images(p)
|
||||
|
||||
grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
|
||||
grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[1].height, prompt_matrix_parts)
|
||||
grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[1].height, prompt_matrix_parts, margin_size)
|
||||
processed.images.insert(0, grid)
|
||||
processed.index_of_first_image = 1
|
||||
processed.infotexts.insert(0, processed.infotexts[0])
|
||||
|
|
|
@ -25,6 +25,8 @@ from modules.ui_components import ToolButton
|
|||
|
||||
fill_values_symbol = "\U0001f4d2" # 📒
|
||||
|
||||
AxisInfo = namedtuple('AxisInfo', ['axis', 'values'])
|
||||
|
||||
|
||||
def apply_field(field):
|
||||
def fun(p, x, xs):
|
||||
|
@ -186,6 +188,7 @@ axis_options = [
|
|||
AxisOption("Steps", int, apply_field("steps")),
|
||||
AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")),
|
||||
AxisOption("CFG Scale", float, apply_field("cfg_scale")),
|
||||
AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),
|
||||
AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
|
||||
AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
|
||||
AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
|
||||
|
@ -205,7 +208,7 @@ axis_options = [
|
|||
]
|
||||
|
||||
|
||||
def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed):
|
||||
def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed, margin_size):
|
||||
hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
|
||||
ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
|
||||
title_texts = [[images.GridAnnotation(z)] for z in z_labels]
|
||||
|
@ -286,23 +289,24 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
|
|||
print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
|
||||
return Processed(p, [])
|
||||
|
||||
grids = [None] * len(zs)
|
||||
sub_grids = [None] * len(zs)
|
||||
for i in range(len(zs)):
|
||||
start_index = i * len(xs) * len(ys)
|
||||
end_index = start_index + len(xs) * len(ys)
|
||||
grid = images.image_grid(image_cache[start_index:end_index], rows=len(ys))
|
||||
if draw_legend:
|
||||
grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts)
|
||||
|
||||
grids[i] = grid
|
||||
grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts, margin_size)
|
||||
sub_grids[i] = grid
|
||||
if include_sub_grids and len(zs) > 1:
|
||||
processed_result.images.insert(i+1, grid)
|
||||
|
||||
original_grid_size = grids[0].size
|
||||
grids = images.image_grid(grids, rows=1)
|
||||
processed_result.images[0] = images.draw_grid_annotations(grids, original_grid_size[0], original_grid_size[1], title_texts, [[images.GridAnnotation()]])
|
||||
sub_grid_size = sub_grids[0].size
|
||||
z_grid = images.image_grid(sub_grids, rows=1)
|
||||
if draw_legend:
|
||||
z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
|
||||
processed_result.images[0] = z_grid
|
||||
|
||||
return processed_result
|
||||
return processed_result, sub_grids
|
||||
|
||||
|
||||
class SharedSettingsStackHelper(object):
|
||||
|
@ -350,10 +354,16 @@ class Script(scripts.Script):
|
|||
fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)
|
||||
|
||||
with gr.Row(variant="compact", elem_id="axis_options"):
|
||||
draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
|
||||
include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
|
||||
include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
|
||||
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
|
||||
with gr.Column():
|
||||
draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
|
||||
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
|
||||
with gr.Column():
|
||||
include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
|
||||
include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
|
||||
with gr.Column():
|
||||
margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
|
||||
|
||||
with gr.Row(variant="compact", elem_id="swap_axes"):
|
||||
swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
|
||||
swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
|
||||
swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button")
|
||||
|
@ -392,9 +402,9 @@ class Script(scripts.Script):
|
|||
(z_values, "Z Values"),
|
||||
)
|
||||
|
||||
return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds]
|
||||
return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
|
||||
|
||||
def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds):
|
||||
def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
|
||||
if not no_fixed_seeds:
|
||||
modules.processing.fix_seed(p)
|
||||
|
||||
|
@ -513,6 +523,10 @@ class Script(scripts.Script):
|
|||
|
||||
grid_infotext = [None]
|
||||
|
||||
state.xyz_plot_x = AxisInfo(x_opt, xs)
|
||||
state.xyz_plot_y = AxisInfo(y_opt, ys)
|
||||
state.xyz_plot_z = AxisInfo(z_opt, zs)
|
||||
|
||||
# If one of the axes is very slow to change between (like SD model
|
||||
# checkpoint), then make sure it is in the outer iteration of the nested
|
||||
# `for` loop.
|
||||
|
@ -576,7 +590,7 @@ class Script(scripts.Script):
|
|||
return res
|
||||
|
||||
with SharedSettingsStackHelper():
|
||||
processed = draw_xyz_grid(
|
||||
processed, sub_grids = draw_xyz_grid(
|
||||
p,
|
||||
xs=xs,
|
||||
ys=ys,
|
||||
|
@ -589,9 +603,14 @@ class Script(scripts.Script):
|
|||
include_lone_images=include_lone_images,
|
||||
include_sub_grids=include_sub_grids,
|
||||
first_axes_processed=first_axes_processed,
|
||||
second_axes_processed=second_axes_processed
|
||||
second_axes_processed=second_axes_processed,
|
||||
margin_size=margin_size
|
||||
)
|
||||
|
||||
if opts.grid_save and len(sub_grids) > 1:
|
||||
for sub_grid in sub_grids:
|
||||
images.save_image(sub_grid, p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
|
||||
|
||||
if opts.grid_save:
|
||||
images.save_image(processed.images[0], p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
|
||||
|
||||
|
|
2
webui.py
2
webui.py
|
@ -20,6 +20,7 @@ import torch
|
|||
|
||||
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
|
||||
if ".dev" in torch.__version__ or "+git" in torch.__version__:
|
||||
torch.__long_version__ = torch.__version__
|
||||
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
|
||||
|
||||
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
|
||||
|
@ -97,7 +98,6 @@ def initialize():
|
|||
modules.sd_models.setup_model()
|
||||
codeformer.setup_model(cmd_opts.codeformer_models_path)
|
||||
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
|
||||
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
|
||||
|
||||
modelloader.list_builtin_upscalers()
|
||||
modules.scripts.load_scripts()
|
||||
|
|
Loading…
Reference in a new issue