Merge branch 'master' into gallery-styling

This commit is contained in:
AUTOMATIC1111 2022-10-06 20:30:29 +03:00 committed by GitHub
commit ab4ddbf333
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
51 changed files with 2249 additions and 714 deletions

1
.gitignore vendored
View file

@ -25,3 +25,4 @@ __pycache__
/.idea /.idea
notification.mp3 notification.mp3
/SwinIR /SwinIR
/textual_inversion

View file

@ -11,44 +11,56 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- One click install and run script (but you still must install python and git) - One click install and run script (but you still must install python and git)
- Outpainting - Outpainting
- Inpainting - Inpainting
- Prompt matrix - Prompt Matrix
- Stable Diffusion upscale - Stable Diffusion Upscale
- Attention - Attention, specify parts of text that the model should pay more attention to
- Loopback - a man in a ((tuxedo)) - will pay more attention to tuxedo
- X/Y plot - a man in a (tuxedo:1.21) - alternative syntax
- Loopback, run img2img processing multiple times
- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
- Textual Inversion - Textual Inversion
- have as many embeddings as you want and use any names you like for them
- use multiple embeddings with different numbers of vectors per token
- works with half precision floating point numbers
- Extras tab with: - Extras tab with:
- GFPGAN, neural network that fixes faces - GFPGAN, neural network that fixes faces
- CodeFormer, face restoration tool as an alternative to GFPGAN - CodeFormer, face restoration tool as an alternative to GFPGAN
- RealESRGAN, neural network upscaler - RealESRGAN, neural network upscaler
- ESRGAN, neural network with a lot of third party models - ESRGAN, neural network upscaler with a lot of third party models
- SwinIR, neural network upscaler - SwinIR, neural network upscaler
- LDSR, Latent diffusion super resolution upscaling - LDSR, Latent diffusion super resolution upscaling
- Resizing aspect ratio options - Resizing aspect ratio options
- Sampling method selection - Sampling method selection
- Interrupt processing at any time - Interrupt processing at any time
- 4GB video card support - 4GB video card support (also reports of 2GB working)
- Correct seeds for batches - Correct seeds for batches
- Prompt length validation - Prompt length validation
- Generation parameters added as text to PNG - get length of prompt in tokens as you type
- Tab to view an existing picture's generation parameters - get a warning after generation if some text was truncated
- Generation parameters
- parameters you used to generate images are saved with that image
- in PNG chunks for PNG, in EXIF for JPEG
- can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
- can be disabled in settings
- Settings page - Settings page
- Running custom code from UI - Running arbitrary python code from UI (must run with --allow-code to enable)
- Mouseover hints for most UI elements - Mouseover hints for most UI elements
- Possible to change defaults/mix/max/step values for UI elements via text config - Possible to change defaults/mix/max/step values for UI elements via text config
- Random artist button - Random artist button
- Tiling support: UI checkbox to create images that can be tiled like textures - Tiling support, a checkbox to create images that can be tiled like textures
- Progress bar and live image generation preview - Progress bar and live image generation preview
- Negative prompt - Negative prompt, an extra text field that allows you to list what you don't want to see in generated image
- Styles - Styles, a way to save part of prompt and easily apply them via dropdown later
- Variations - Variations, a way to generate same image but with tiny differences
- Seed resizing - Seed resizing, a way to generate same image but at slightly different resolution
- CLIP interrogator - CLIP interrogator, a button that tries to guess prompt from an image
- Prompt Editing - Prompt Editing, a way to change prompt mid-generation, say to start making a watermelon and switch to anime girl midway
- Batch Processing - Batch Processing, process a group of files using img2img
- Img2img Alternative - Img2img Alternative
- Highres Fix - Highres Fix, a convenience option to produce high resolution pictures in one click without usual distortions
- LDSR Upscaling - Reloading checkpoints on the fly
- Checkpoint Merger, a tab that allows you to merge two checkpoints into one
- [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from community
## Installation and Running ## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
@ -101,6 +113,7 @@ The documentation was moved from this README over to the project's [wiki](https:
- LDSR - https://github.com/Hafiidz/latent-diffusion - LDSR - https://github.com/Hafiidz/latent-diffusion
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion - Ideas for optimizations - https://github.com/basujindal/stable-diffusion
- Doggettx - Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing. - Doggettx - Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
- Rinon Gal - Textual Inversion - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd - Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
- CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator - CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator

View file

@ -15,7 +15,7 @@ titles = {
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed", "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
"\u{1f3a8}": "Add a random artist to the prompt.", "\u{1f3a8}": "Add a random artist to the prompt.",
"\u2199\ufe0f": "Read generation parameters from prompt into user interface.", "\u2199\ufe0f": "Read generation parameters from prompt into user interface.",
"\uD83D\uDCC2": "Open images output directory", "\u{1f4c2}": "Open images output directory",
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
@ -47,6 +47,7 @@ titles = {
"Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work", "Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work",
"Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others", "Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others",
"Prompt order": "Separate a list of words with commas, and the script will make a variation of prompt with those words for their every possible order",
"Tiling": "Produce an image that can be tiled.", "Tiling": "Produce an image that can be tiled.",
"Tile overlap": "For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.", "Tile overlap": "For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.",

View file

@ -4,6 +4,21 @@ global_progressbars = {}
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_interrupt, id_preview, id_gallery){ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_interrupt, id_preview, id_gallery){
var progressbar = gradioApp().getElementById(id_progressbar) var progressbar = gradioApp().getElementById(id_progressbar)
var interrupt = gradioApp().getElementById(id_interrupt) var interrupt = gradioApp().getElementById(id_interrupt)
if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
if(progressbar.innerText){
let newtitle = 'Stable Diffusion - ' + progressbar.innerText
if(document.title != newtitle){
document.title = newtitle;
}
}else{
let newtitle = 'Stable Diffusion'
if(document.title != newtitle){
document.title = newtitle;
}
}
}
if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){ if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){
global_progressbars[id_progressbar] = progressbar global_progressbars[id_progressbar] = progressbar
@ -30,6 +45,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte
onUiUpdate(function(){ onUiUpdate(function(){
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery') check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery') check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', 'ti_interrupt', 'ti_preview', 'ti_gallery')
}) })
function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){ function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){

View file

@ -0,0 +1,8 @@
function start_training_textual_inversion(){
requestProgress('ti')
gradioApp().querySelector('#ti_error').innerHTML=''
return args_to_array(arguments)
}

View file

@ -186,10 +186,12 @@ onUiUpdate(function(){
if (!txt2img_textarea) { if (!txt2img_textarea) {
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea"); txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button")); txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
txt2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "txt2img_generate"));
} }
if (!img2img_textarea) { if (!img2img_textarea) {
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea"); img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button")); img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
img2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "img2img_generate"));
} }
}) })
@ -197,8 +199,35 @@ let txt2img_textarea, img2img_textarea = undefined;
let wait_time = 800 let wait_time = 800
let token_timeout; let token_timeout;
function update_txt2img_tokens(...args) {
update_token_counter("txt2img_token_button")
if (args.length == 2)
return args[0]
return args;
}
function update_img2img_tokens(...args) {
update_token_counter("img2img_token_button")
if (args.length == 2)
return args[0]
return args;
}
function update_token_counter(button_id) { function update_token_counter(button_id) {
if (token_timeout) if (token_timeout)
clearTimeout(token_timeout); clearTimeout(token_timeout);
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
} }
function submit_prompt(event, generate_button_id) {
if (event.altKey && event.keyCode === 13) {
event.preventDefault();
gradioApp().getElementById(generate_button_id).click();
return;
}
}
function restart_reload(){
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
setTimeout(function(){location.reload()},2000)
}

View file

@ -15,10 +15,11 @@ requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "") commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379") gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc") stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6") taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "a7ec1974d4ccb394c2dca275f42cd97490618924") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "f4e99857772fc3a126ba886aadf795a332774878")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
@ -85,6 +86,15 @@ def git_clone(url, dir, name, commithash=None):
# TODO clone into temporary dir and move if successful # TODO clone into temporary dir and move if successful
if os.path.exists(dir): if os.path.exists(dir):
if commithash is None:
return
current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
if current_hash == commithash:
return
run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commint for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
return return
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}") run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
@ -111,6 +121,9 @@ if not skip_torch_cuda_test:
if not is_installed("gfpgan"): if not is_installed("gfpgan"):
run_pip(f"install {gfpgan_package}", "gfpgan") run_pip(f"install {gfpgan_package}", "gfpgan")
if not is_installed("clip"):
run_pip(f"install {clip_package}", "clip")
os.makedirs(dir_repos, exist_ok=True) os.makedirs(dir_repos, exist_ok=True)
git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash) git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)

View file

@ -8,7 +8,7 @@ import torch
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
import modules.upscaler import modules.upscaler
from modules import shared, modelloader from modules import devices, modelloader
from modules.bsrgan_model_arch import RRDBNet from modules.bsrgan_model_arch import RRDBNet
from modules.paths import models_path from modules.paths import models_path
@ -44,13 +44,13 @@ class UpscalerBSRGAN(modules.upscaler.Upscaler):
model = self.load_model(selected_file) model = self.load_model(selected_file)
if model is None: if model is None:
return img return img
model.to(shared.device) model.to(devices.device_bsrgan)
torch.cuda.empty_cache() torch.cuda.empty_cache()
img = np.array(img) img = np.array(img)
img = img[:, :, ::-1] img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255 img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(shared.device) img = img.unsqueeze(0).to(devices.device_bsrgan)
with torch.no_grad(): with torch.no_grad():
output = model(img) output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy() output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
@ -67,10 +67,9 @@ class UpscalerBSRGAN(modules.upscaler.Upscaler):
else: else:
filename = path filename = path
if not os.path.exists(filename) or filename is None: if not os.path.exists(filename) or filename is None:
print("Unable to load %s from %s" % (self.model_dir, filename)) print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
return None return None
print("Loading %s from %s" % (self.model_dir, filename)) model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network
model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2) # define network
model.load_state_dict(torch.load(filename), strict=True) model.load_state_dict(torch.load(filename), strict=True)
model.eval() model.eval()
for k, v in model.named_parameters(): for k, v in model.named_parameters():

View file

@ -76,7 +76,6 @@ class RRDBNet(nn.Module):
super(RRDBNet, self).__init__() super(RRDBNet, self).__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.sf = sf self.sf = sf
print([in_nc, out_nc, nf, nb, gc, sf])
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.RRDB_trunk = make_layer(RRDB_block_f, nb) self.RRDB_trunk = make_layer(RRDB_block_f, nb)

View file

@ -69,10 +69,14 @@ def setup_model(dirname):
self.net = net self.net = net
self.face_helper = face_helper self.face_helper = face_helper
self.net.to(devices.device_codeformer)
return net, face_helper return net, face_helper
def send_model_to(self, device):
self.net.to(device)
self.face_helper.face_det.to(device)
self.face_helper.face_parse.to(device)
def restore(self, np_image, w=None): def restore(self, np_image, w=None):
np_image = np_image[:, :, ::-1] np_image = np_image[:, :, ::-1]
@ -82,6 +86,8 @@ def setup_model(dirname):
if self.net is None or self.face_helper is None: if self.net is None or self.face_helper is None:
return np_image return np_image
self.send_model_to(devices.device_codeformer)
self.face_helper.clean_all() self.face_helper.clean_all()
self.face_helper.read_image(np_image) self.face_helper.read_image(np_image)
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
@ -113,8 +119,10 @@ def setup_model(dirname):
if original_resolution != restored_img.shape[0:2]: if original_resolution != restored_img.shape[0:2]:
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR) restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
self.face_helper.clean_all()
if shared.opts.face_restoration_unload: if shared.opts.face_restoration_unload:
self.net.to(devices.cpu) self.send_model_to(devices.cpu)
return restored_img return restored_img

View file

@ -1,8 +1,10 @@
import contextlib
import torch import torch
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
from modules import errors from modules import errors
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
has_mps = getattr(torch, 'has_mps', False) has_mps = getattr(torch, 'has_mps', False)
cpu = torch.device("cpu") cpu = torch.device("cpu")
@ -32,10 +34,8 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32") errors.run(enable_tf32, "Enabling TF32")
device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
device = get_optimal_device() dtype = torch.float16
device_codeformer = cpu if has_mps else device
def randn(seed, shape): def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
@ -58,3 +58,11 @@ def randn_without_seed(shape):
return torch.randn(shape, device=device) return torch.randn(shape, device=device)
def autocast():
from modules import shared
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext()
return torch.autocast("cuda")

View file

@ -6,8 +6,7 @@ from PIL import Image
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
import modules.esrgam_model_arch as arch import modules.esrgam_model_arch as arch
from modules import shared, modelloader, images from modules import shared, modelloader, images, devices
from modules.devices import has_mps
from modules.paths import models_path from modules.paths import models_path
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts from modules.shared import opts
@ -73,8 +72,8 @@ def fix_model_layers(crt_model, pretrained_net):
class UpscalerESRGAN(Upscaler): class UpscalerESRGAN(Upscaler):
def __init__(self, dirname): def __init__(self, dirname):
self.name = "ESRGAN" self.name = "ESRGAN"
self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download" self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
self.model_name = "ESRGAN 4x" self.model_name = "ESRGAN_4x"
self.scalers = [] self.scalers = []
self.user_path = dirname self.user_path = dirname
self.model_path = os.path.join(models_path, self.name) self.model_path = os.path.join(models_path, self.name)
@ -97,7 +96,7 @@ class UpscalerESRGAN(Upscaler):
model = self.load_model(selected_model) model = self.load_model(selected_model)
if model is None: if model is None:
return img return img
model.to(shared.device) model.to(devices.device_esrgan)
img = esrgan_upscale(model, img) img = esrgan_upscale(model, img)
return img return img
@ -112,7 +111,7 @@ class UpscalerESRGAN(Upscaler):
print("Unable to load %s from %s" % (self.model_path, filename)) print("Unable to load %s from %s" % (self.model_path, filename))
return None return None
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None) pretrained_net = torch.load(filename, map_location='cpu' if shared.device.type == 'mps' else None)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
pretrained_net = fix_model_layers(crt_model, pretrained_net) pretrained_net = fix_model_layers(crt_model, pretrained_net)
@ -127,7 +126,7 @@ def upscale_without_tiling(model, img):
img = img[:, :, ::-1] img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255 img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(shared.device) img = img.unsqueeze(0).to(devices.device_esrgan)
with torch.no_grad(): with torch.no_grad():
output = model(img) output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy() output = output.squeeze().float().cpu().clamp_(0, 1).numpy()

View file

@ -100,6 +100,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
outputs.append(image) outputs.append(image)
devices.torch_gc()
return outputs, plaintext_to_html(info), '' return outputs, plaintext_to_html(info), ''
@ -191,9 +193,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
if save_as_half: if save_as_half:
theta_0[key] = theta_0[key].half() theta_0[key] = theta_0[key].half()
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt' filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
filename = filename if custom_name == '' else (custom_name + '.ckpt') filename = filename if custom_name == '' else (custom_name + '.ckpt')
output_modelname = os.path.join(shared.cmd_opts.ckpt_dir, filename) output_modelname = os.path.join(ckpt_dir, filename)
print(f"Saving to {output_modelname}...") print(f"Saving to {output_modelname}...")
torch.save(primary_model, output_modelname) torch.save(primary_model, output_modelname)

View file

@ -21,7 +21,7 @@ def gfpgann():
global loaded_gfpgan_model global loaded_gfpgan_model
global model_path global model_path
if loaded_gfpgan_model is not None: if loaded_gfpgan_model is not None:
loaded_gfpgan_model.gfpgan.to(shared.device) loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
return loaded_gfpgan_model return loaded_gfpgan_model
if gfpgan_constructor is None: if gfpgan_constructor is None:
@ -37,22 +37,32 @@ def gfpgann():
print("Unable to load gfpgan model!") print("Unable to load gfpgan model!")
return None return None
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
model.gfpgan.to(shared.device)
loaded_gfpgan_model = model loaded_gfpgan_model = model
return model return model
def send_model_to(model, device):
model.gfpgan.to(device)
model.face_helper.face_det.to(device)
model.face_helper.face_parse.to(device)
def gfpgan_fix_faces(np_image): def gfpgan_fix_faces(np_image):
model = gfpgann() model = gfpgann()
if model is None: if model is None:
return np_image return np_image
send_model_to(model, devices.device_gfpgan)
np_image_bgr = np_image[:, :, ::-1] np_image_bgr = np_image[:, :, ::-1]
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1] np_image = gfpgan_output_bgr[:, :, ::-1]
model.face_helper.clean_all()
if shared.opts.face_restoration_unload: if shared.opts.face_restoration_unload:
model.gfpgan.to(devices.cpu) send_model_to(model, devices.cpu)
return np_image return np_image
@ -97,11 +107,7 @@ def setup_model(dirname):
return "GFPGAN" return "GFPGAN"
def restore(self, np_image): def restore(self, np_image):
np_image_bgr = np_image[:, :, ::-1] return gfpgan_fix_faces(np_image)
cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1]
return np_image
shared.face_restorers.append(FaceRestorerGFPGAN()) shared.face_restorers.append(FaceRestorerGFPGAN())
except Exception: except Exception:

View file

@ -213,17 +213,19 @@ def resize_image(resize_mode, im, width, height):
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L': if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
return im.resize((w, h), resample=LANCZOS) return im.resize((w, h), resample=LANCZOS)
upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
upscaler = upscalers[0]
scale = max(w / im.width, h / im.height) scale = max(w / im.width, h / im.height)
upscaled = upscaler.scaler.upscale(im, scale, upscaler.data_path)
if upscaled.width != w or upscaled.height != h: if scale > 1.0:
upscaled = im.resize((w, h), resample=LANCZOS) upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
return upscaled upscaler = upscalers[0]
im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
if im.width != w or im.height != h:
im = im.resize((w, h), resample=LANCZOS)
return im
if resize_mode == 0: if resize_mode == 0:
res = resize(im, width, height) res = resize(im, width, height)
@ -285,6 +287,25 @@ def apply_filename_pattern(x, p, seed, prompt):
if seed is not None: if seed is not None:
x = x.replace("[seed]", str(seed)) x = x.replace("[seed]", str(seed))
if p is not None:
x = x.replace("[steps]", str(p.steps))
x = x.replace("[cfg]", str(p.cfg_scale))
x = x.replace("[width]", str(p.width))
x = x.replace("[height]", str(p.height))
#currently disabled if using the save button, will work otherwise
# if enabled it will cause a bug because styles is not included in the save_files data dictionary
if hasattr(p, "styles"):
x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False))
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
x = x.replace("[date]", datetime.date.today().isoformat())
x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
x = x.replace("[job_timestamp]", shared.state.job_timestamp)
# Apply [prompt] at last. Because it may contain any replacement word.^M
if prompt is not None: if prompt is not None:
x = x.replace("[prompt]", sanitize_filename_part(prompt)) x = x.replace("[prompt]", sanitize_filename_part(prompt))
if "[prompt_no_styles]" in x: if "[prompt_no_styles]" in x:
@ -293,7 +314,7 @@ def apply_filename_pattern(x, p, seed, prompt):
if len(style) > 0: if len(style) > 0:
style_parts = [y for y in style.split("{prompt}")] style_parts = [y for y in style.split("{prompt}")]
for part in style_parts: for part in style_parts:
prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',') prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip() prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False)) x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False))
@ -304,19 +325,6 @@ def apply_filename_pattern(x, p, seed, prompt):
words = ["empty"] words = ["empty"]
x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False)) x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
if p is not None:
x = x.replace("[steps]", str(p.steps))
x = x.replace("[cfg]", str(p.cfg_scale))
x = x.replace("[width]", str(p.width))
x = x.replace("[height]", str(p.height))
x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]), replace_spaces=False))
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
x = x.replace("[date]", datetime.date.today().isoformat())
x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
x = x.replace("[job_timestamp]", shared.state.job_timestamp)
if cmd_opts.hide_ui_dir_config: if cmd_opts.hide_ui_dir_config:
x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x) x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x)
@ -345,7 +353,7 @@ def get_next_sequence_number(path, basename):
return result + 1 return result + 1
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=""): def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
if short_filename or prompt is None or seed is None: if short_filename or prompt is None or seed is None:
file_decoration = "" file_decoration = ""
elif opts.save_to_dirs: elif opts.save_to_dirs:
@ -369,10 +377,11 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
else: else:
pnginfo = None pnginfo = None
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt) if save_to_dirs is None:
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
if save_to_dirs: if save_to_dirs:
dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt) dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /')
path = os.path.join(path, dirname) path = os.path.join(path, dirname)
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
@ -423,4 +432,4 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file: with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file:
file.write(info + "\n") file.write(info + "\n")
return fullfn

View file

@ -23,8 +23,10 @@ def process_batch(p, input_dir, output_dir, args):
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.") print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
save_normally = output_dir == ''
p.do_not_save_grid = True p.do_not_save_grid = True
p.do_not_save_samples = True p.do_not_save_samples = not save_normally
state.job_count = len(images) * p.n_iter state.job_count = len(images) * p.n_iter
@ -48,7 +50,8 @@ def process_batch(p, input_dir, output_dir, args):
left, right = os.path.splitext(filename) left, right = os.path.splitext(filename)
filename = f"{left}-{n}{right}" filename = f"{left}-{n}{right}"
processed_image.save(os.path.join(output_dir, filename)) if not save_normally:
processed_image.save(os.path.join(output_dir, filename))
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
@ -103,7 +106,9 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
inpaint_full_res_padding=inpaint_full_res_padding, inpaint_full_res_padding=inpaint_full_res_padding,
inpainting_mask_invert=inpainting_mask_invert, inpainting_mask_invert=inpainting_mask_invert,
) )
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
p.extra_generation_params["Mask blur"] = mask_blur p.extra_generation_params["Mask blur"] = mask_blur
@ -124,4 +129,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
if opts.samples_log_stdout: if opts.samples_log_stdout:
print(generation_info_js) print(generation_info_js)
if opts.do_not_show_images:
processed.images = []
return processed.images, generation_info_js, plaintext_to_html(processed.info) return processed.images, generation_info_js, plaintext_to_html(processed.info)

View file

@ -21,6 +21,7 @@ Category = namedtuple("Category", ["name", "topn", "items"])
re_topn = re.compile(r"\.top(\d+)\.") re_topn = re.compile(r"\.top(\d+)\.")
class InterrogateModels: class InterrogateModels:
blip_model = None blip_model = None
clip_model = None clip_model = None

View file

@ -22,9 +22,21 @@ class UpscalerLDSR(Upscaler):
self.scalers = [scaler_data] self.scalers = [scaler_data]
def load_model(self, path: str): def load_model(self, path: str):
# Remove incorrect project.yaml file if too big
yaml_path = os.path.join(self.model_path, "project.yaml")
old_model_path = os.path.join(self.model_path, "model.pth")
new_model_path = os.path.join(self.model_path, "model.ckpt")
if os.path.exists(yaml_path):
statinfo = os.stat(yaml_path)
if statinfo.st_size >= 10485760:
print("Removing invalid LDSR YAML file.")
os.remove(yaml_path)
if os.path.exists(old_model_path):
print("Renaming model from model.pth to model.ckpt")
os.rename(old_model_path, new_model_path)
model = load_file_from_url(url=self.model_url, model_dir=self.model_path, model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
file_name="model.pth", progress=True) file_name="model.ckpt", progress=True)
yaml = load_file_from_url(url=self.model_url, model_dir=self.model_path, yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
file_name="project.yaml", progress=True) file_name="project.yaml", progress=True)
try: try:
@ -41,5 +53,4 @@ class UpscalerLDSR(Upscaler):
print("NO LDSR!") print("NO LDSR!")
return img return img
ddim_steps = shared.opts.ldsr_steps ddim_steps = shared.opts.ldsr_steps
pre_scale = shared.opts.ldsr_pre_down
return ldsr.super_resolution(img, ddim_steps, self.scale) return ldsr.super_resolution(img, ddim_steps, self.scale)

View file

@ -98,9 +98,7 @@ class LDSR:
im_og = image im_og = image
width_og, height_og = im_og.size width_og, height_og = im_og.size
# If we can adjust the max upscale size, then the 4 below should be our variable # If we can adjust the max upscale size, then the 4 below should be our variable
print("Foo")
down_sample_rate = target_scale / 4 down_sample_rate = target_scale / 4
print(f"Downsample rate is {down_sample_rate}")
wd = width_og * down_sample_rate wd = width_og * down_sample_rate
hd = height_og * down_sample_rate hd = height_og * down_sample_rate
width_downsampled_pre = int(wd) width_downsampled_pre = int(wd)
@ -111,7 +109,7 @@ class LDSR:
f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]') f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS) im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
else: else:
print(f"Down sample rate is 1 from {target_scale} / 4") print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
logs = self.run(model["model"], im_og, diffusion_steps, eta) logs = self.run(model["model"], im_og, diffusion_steps, eta)
sample = logs["sample"] sample = logs["sample"]

View file

@ -1,10 +1,10 @@
import glob
import os import os
import shutil import shutil
import importlib import importlib
from urllib.parse import urlparse from urllib.parse import urlparse
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
from modules import shared from modules import shared
from modules.upscaler import Upscaler from modules.upscaler import Upscaler
from modules.paths import script_path, models_path from modules.paths import script_path, models_path
@ -41,8 +41,8 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
for place in places: for place in places:
if os.path.exists(place): if os.path.exists(place):
for file in os.listdir(place): for file in glob.iglob(place + '**/**', recursive=True):
full_path = os.path.join(place, file) full_path = file
if os.path.isdir(full_path): if os.path.isdir(full_path):
continue continue
if len(ext_filter) != 0: if len(ext_filter) != 0:
@ -120,16 +120,30 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
def load_upscalers(): def load_upscalers():
sd = shared.script_path
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__
modules_dir = os.path.join(sd, "modules")
for file in os.listdir(modules_dir):
if "_model.py" in file:
model_name = file.replace("_model.py", "")
full_model = f"modules.{model_name}_model"
try:
importlib.import_module(full_model)
except:
pass
datas = [] datas = []
c_o = vars(shared.cmd_opts)
for cls in Upscaler.__subclasses__(): for cls in Upscaler.__subclasses__():
name = cls.__name__ name = cls.__name__
module_name = cls.__module__ module_name = cls.__module__
module = importlib.import_module(module_name) module = importlib.import_module(module_name)
class_ = getattr(module, name) class_ = getattr(module, name)
cmd_name = f"{name.lower().replace('upscaler', '')}-models-path" cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
opt_string = None opt_string = None
try: try:
opt_string = shared.opts.__getattr__(cmd_name) if cmd_name in c_o:
opt_string = c_o[cmd_name]
except: except:
pass pass
scaler = class_(opt_string) scaler = class_(opt_string)

View file

@ -20,7 +20,6 @@ path_dirs = [
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []), (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
] ]

View file

@ -1,4 +1,3 @@
import contextlib
import json import json
import math import math
import os import os
@ -12,9 +11,8 @@ import cv2
from skimage import exposure from skimage import exposure
import modules.sd_hijack import modules.sd_hijack
from modules import devices, prompt_parser, masking from modules import devices, prompt_parser, masking, sd_samplers, lowvram
from modules.sd_hijack import model_hijack from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
import modules.face_restoration import modules.face_restoration
@ -56,7 +54,7 @@ class StableDiffusionProcessing:
self.prompt: str = prompt self.prompt: str = prompt
self.prompt_for_display: str = None self.prompt_for_display: str = None
self.negative_prompt: str = (negative_prompt or "") self.negative_prompt: str = (negative_prompt or "")
self.styles: str = styles self.styles: list = styles or []
self.seed: int = seed self.seed: int = seed
self.subseed: int = subseed self.subseed: int = subseed
self.subseed_strength: float = subseed_strength self.subseed_strength: float = subseed_strength
@ -79,13 +77,13 @@ class StableDiffusionProcessing:
self.paste_to = None self.paste_to = None
self.color_corrections = None self.color_corrections = None
self.denoising_strength: float = 0 self.denoising_strength: float = 0
self.sampler_noise_scheduler_override = None
self.ddim_discretize = opts.ddim_discretize self.ddim_discretize = opts.ddim_discretize
self.s_churn = opts.s_churn self.s_churn = opts.s_churn
self.s_tmin = opts.s_tmin self.s_tmin = opts.s_tmin
self.s_tmax = float('inf') # not representable as a standard ui option self.s_tmax = float('inf') # not representable as a standard ui option
self.s_noise = opts.s_noise self.s_noise = opts.s_noise
if not seed_enable_extras: if not seed_enable_extras:
self.subseed = -1 self.subseed = -1
self.subseed_strength = 0 self.subseed_strength = 0
@ -111,7 +109,7 @@ class Processed:
self.width = p.width self.width = p.width
self.height = p.height self.height = p.height
self.sampler_index = p.sampler_index self.sampler_index = p.sampler_index
self.sampler = samplers[p.sampler_index].name self.sampler = sd_samplers.samplers[p.sampler_index].name
self.cfg_scale = p.cfg_scale self.cfg_scale = p.cfg_scale
self.steps = p.steps self.steps = p.steps
self.batch_size = p.batch_size self.batch_size = p.batch_size
@ -130,7 +128,7 @@ class Processed:
self.s_tmin = p.s_tmin self.s_tmin = p.s_tmin
self.s_tmax = p.s_tmax self.s_tmax = p.s_tmax
self.s_noise = p.s_noise self.s_noise = p.s_noise
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0] self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0] self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
@ -249,9 +247,16 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
return x return x
def get_fixed_seed(seed):
if seed is None or seed == '' or seed == -1:
return int(random.randrange(4294967294))
return seed
def fix_seed(p): def fix_seed(p):
p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == '' or p.seed == -1 else p.seed p.seed = get_fixed_seed(p.seed)
p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == '' or p.subseed == -1 else p.subseed p.subseed = get_fixed_seed(p.subseed)
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0): def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
@ -259,7 +264,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params = { generation_params = {
"Steps": p.steps, "Steps": p.steps,
"Sampler": samplers[p.sampler_index].name, "Sampler": sd_samplers.samplers[p.sampler_index].name,
"CFG scale": p.cfg_scale, "CFG scale": p.cfg_scale,
"Seed": all_seeds[index], "Seed": all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
@ -271,7 +276,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None), "Denoising strength": getattr(p, 'denoising_strength', None),
"Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
} }
generation_params.update(p.extra_generation_params) generation_params.update(p.extra_generation_params)
@ -290,13 +295,17 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
assert(len(p.prompt) > 0) assert(len(p.prompt) > 0)
else: else:
assert p.prompt is not None assert p.prompt is not None
devices.torch_gc() devices.torch_gc()
fix_seed(p) seed = get_fixed_seed(p.seed)
subseed = get_fixed_seed(p.subseed)
os.makedirs(p.outpath_samples, exist_ok=True) if p.outpath_samples is not None:
os.makedirs(p.outpath_grids, exist_ok=True) os.makedirs(p.outpath_samples, exist_ok=True)
if p.outpath_grids is not None:
os.makedirs(p.outpath_grids, exist_ok=True)
modules.sd_hijack.model_hijack.apply_circular(p.tiling) modules.sd_hijack.model_hijack.apply_circular(p.tiling)
@ -309,28 +318,28 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
else: else:
all_prompts = p.batch_size * p.n_iter * [p.prompt] all_prompts = p.batch_size * p.n_iter * [p.prompt]
if type(p.seed) == list: if type(seed) == list:
all_seeds = p.seed all_seeds = seed
else: else:
all_seeds = [int(p.seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))] all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
if type(p.subseed) == list: if type(subseed) == list:
all_subseeds = p.subseed all_subseeds = subseed
else: else:
all_subseeds = [int(p.subseed) + x for x in range(len(all_prompts))] all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]
def infotext(iteration=0, position_in_batch=0): def infotext(iteration=0, position_in_batch=0):
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch) return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
if os.path.exists(cmd_opts.embeddings_dir): if os.path.exists(cmd_opts.embeddings_dir):
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model) model_hijack.embedding_db.load_textual_inversion_embeddings()
infotexts = [] infotexts = []
output_images = [] output_images = []
precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope) with torch.no_grad():
with torch.no_grad(), precision_scope("cuda"), ema_scope(): with devices.autocast():
p.init(all_prompts, all_seeds, all_subseeds) p.init(all_prompts, all_seeds, all_subseeds)
if state.job_count == -1: if state.job_count == -1:
state.job_count = p.n_iter state.job_count = p.n_iter
@ -348,8 +357,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
#uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
#c = p.sd_model.get_learned_conditioning(prompts) #c = p.sd_model.get_learned_conditioning(prompts)
uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps) with devices.autocast():
c = prompt_parser.get_learned_conditioning(prompts, p.steps) uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
if len(model_hijack.comments) > 0: if len(model_hijack.comments) > 0:
for comment in model_hijack.comments: for comment in model_hijack.comments:
@ -358,16 +368,27 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.n_iter > 1: if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}" shared.state.job = f"Batch {n+1} out of {p.n_iter}"
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) with devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
if state.interrupted: if state.interrupted:
# if we are interruped, sample returns just noise # if we are interruped, sample returns just noise
# use the image collected previously in sampler loop # use the image collected previously in sampler loop
samples_ddim = shared.state.current_latent samples_ddim = shared.state.current_latent
samples_ddim = samples_ddim.to(devices.dtype)
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim) x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
del samples_ddim
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
devices.torch_gc()
if opts.filter_nsfw: if opts.filter_nsfw:
import modules.safety as safety import modules.safety as safety
x_samples_ddim = modules.safety.censor_batch(x_samples_ddim) x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
@ -383,6 +404,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc() devices.torch_gc()
x_sample = modules.face_restoration.restore_faces(x_sample) x_sample = modules.face_restoration.restore_faces(x_sample)
devices.torch_gc()
image = Image.fromarray(x_sample) image = Image.fromarray(x_sample)
@ -408,9 +430,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if opts.samples_save and not p.do_not_save_samples: if opts.samples_save and not p.do_not_save_samples:
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p) images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
infotexts.append(infotext(n, i)) text = infotext(n, i)
infotexts.append(text)
image.info["parameters"] = text
output_images.append(image) output_images.append(image)
del x_samples_ddim
devices.torch_gc()
state.nextjob() state.nextjob()
p.color_corrections = None p.color_corrections = None
@ -421,7 +449,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
grid = images.image_grid(output_images, p.batch_size) grid = images.image_grid(output_images, p.batch_size)
if opts.return_grid: if opts.return_grid:
infotexts.insert(0, infotext()) text = infotext()
infotexts.insert(0, text)
grid.info["parameters"] = text
output_images.insert(0, grid) output_images.insert(0, grid)
index_of_first_image = 1 index_of_first_image = 1
@ -462,7 +492,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.firstphase_height_truncated = int(scale * self.height) self.firstphase_height_truncated = int(scale * self.height)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = samplers[self.sampler_index].constructor(self.sd_model) self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr: if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
@ -505,13 +535,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
shared.state.nextjob() shared.state.nextjob()
self.sampler = samplers[self.sampler_index].constructor(self.sd_model) self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
# GC now before running the next img2img to prevent running out of memory # GC now before running the next img2img to prevent running out of memory
x = None x = None
devices.torch_gc() devices.torch_gc()
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
return samples return samples
@ -540,7 +571,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.nmask = None self.nmask = None
def init(self, all_prompts, all_seeds, all_subseeds): def init(self, all_prompts, all_seeds, all_subseeds):
self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model) self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
crop_region = None crop_region = None
if self.image_mask is not None: if self.image_mask is not None:
@ -647,4 +678,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.mask is not None: if self.mask is not None:
samples = samples * self.nmask + self.init_latent * self.mask samples = samples * self.nmask + self.init_latent * self.mask
del x
devices.torch_gc()
return samples return samples

View file

@ -1,19 +1,7 @@
import re import re
from collections import namedtuple from collections import namedtuple
import torch from typing import List
import lark
import modules.shared as shared
re_prompt = re.compile(r'''
(.*?)
\[
([^]:]+):
(?:([^]:]*):)?
([0-9]*\.?[0-9]+)
]
|
(.+)
''', re.X)
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100): # will be represented with prompt_schedule like this (assuming steps=100):
@ -23,71 +11,112 @@ re_prompt = re.compile(r'''
# [75, 'fantasy landscape with a lake and an oak in background masterful'] # [75, 'fantasy landscape with a lake and an oak in background masterful']
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful'] # [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
schedule_parser = lark.Lark(r"""
!start: (prompt | /[][():]/+)*
prompt: (emphasized | scheduled | plain | WHITESPACE)*
!emphasized: "(" prompt ")"
| "(" prompt ":" prompt ")"
| "[" prompt "]"
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
WHITESPACE: /\s+/
plain: /([^\\\[\]():]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
""")
def get_learned_conditioning_prompt_schedules(prompts, steps): def get_learned_conditioning_prompt_schedules(prompts, steps):
res = [] """
cache = {} >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
>>> g("test")
[[10, 'test']]
>>> g("a [b:3]")
[[3, 'a '], [10, 'a b']]
>>> g("a [b: 3]")
[[3, 'a '], [10, 'a b']]
>>> g("a [[[b]]:2]")
[[2, 'a '], [10, 'a [[b]]']]
>>> g("[(a:2):3]")
[[3, ''], [10, '(a:2)']]
>>> g("a [b : c : 1] d")
[[1, 'a b d'], [10, 'a c d']]
>>> g("a[b:[c:d:2]:1]e")
[[1, 'abe'], [2, 'ace'], [10, 'ade']]
>>> g("a [unbalanced")
[[10, 'a [unbalanced']]
>>> g("a [b:.5] c")
[[5, 'a c'], [10, 'a b c']]
>>> g("a [{b|d{:.5] c") # not handling this right now
[[5, 'a c'], [10, 'a {b|d{ c']]
>>> g("((a][:b:c [d:3]")
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
"""
for prompt in prompts: def collect_steps(steps, tree):
prompt_schedule: list[list[str | int]] = [[steps, ""]] l = [steps]
class CollectSteps(lark.Visitor):
def scheduled(self, tree):
tree.children[-1] = float(tree.children[-1])
if tree.children[-1] < 1:
tree.children[-1] *= steps
tree.children[-1] = min(steps, int(tree.children[-1]))
l.append(tree.children[-1])
CollectSteps().visit(tree)
return sorted(set(l))
cached = cache.get(prompt, None) def at_step(step, tree):
if cached is not None: class AtStep(lark.Transformer):
res.append(cached) def scheduled(self, args):
continue before, after, _, when = args
yield before or () if step <= when else after
def start(self, args):
def flatten(x):
if type(x) == str:
yield x
else:
for gen in x:
yield from flatten(gen)
return ''.join(flatten(args))
def plain(self, args):
yield args[0].value
def __default__(self, data, children, meta):
for child in children:
yield from child
return AtStep().transform(tree)
for m in re_prompt.finditer(prompt): def get_schedule(prompt):
plaintext = m.group(1) if m.group(5) is None else m.group(5) try:
concept_from = m.group(2) tree = schedule_parser.parse(prompt)
concept_to = m.group(3) except lark.exceptions.LarkError as e:
if concept_to is None: if 0:
concept_to = concept_from import traceback
concept_from = "" traceback.print_exc()
swap_position = float(m.group(4)) if m.group(4) is not None else None return [[steps, prompt]]
return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
if swap_position is not None: promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
if swap_position < 1: return [promptdict[prompt] for prompt in prompts]
swap_position = swap_position * steps
swap_position = int(min(swap_position, steps))
swap_index = None
found_exact_index = False
for i in range(len(prompt_schedule)):
end_step = prompt_schedule[i][0]
prompt_schedule[i][1] += plaintext
if swap_position is not None and swap_index is None:
if swap_position == end_step:
swap_index = i
found_exact_index = True
if swap_position < end_step:
swap_index = i
if swap_index is not None:
if not found_exact_index:
prompt_schedule.insert(swap_index, [swap_position, prompt_schedule[swap_index][1]])
for i in range(len(prompt_schedule)):
end_step = prompt_schedule[i][0]
must_replace = swap_position < end_step
prompt_schedule[i][1] += concept_to if must_replace else concept_from
res.append(prompt_schedule)
cache[prompt] = prompt_schedule
#for t in prompt_schedule:
# print(t)
return res
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"]) ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"])
def get_learned_conditioning(prompts, steps): def get_learned_conditioning(model, prompts, steps):
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
and the sampling step at which this condition is to be replaced by the next one.
Input:
(model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
Output:
[
[
ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
],
[
ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
]
]
"""
res = [] res = []
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps) prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
@ -101,7 +130,7 @@ def get_learned_conditioning(prompts, steps):
continue continue
texts = [x[1] for x in prompt_schedule] texts = [x[1] for x in prompt_schedule]
conds = shared.sd_model.get_learned_conditioning(texts) conds = model.get_learned_conditioning(texts)
cond_schedule = [] cond_schedule = []
for i, (end_at_step, text) in enumerate(prompt_schedule): for i, (end_at_step, text) in enumerate(prompt_schedule):
@ -110,22 +139,109 @@ def get_learned_conditioning(prompts, steps):
cache[prompt] = cond_schedule cache[prompt] = cond_schedule
res.append(cond_schedule) res.append(cond_schedule)
return ScheduledPromptBatch((len(prompts),) + res[0][0].cond.shape, res) return res
def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step): re_AND = re.compile(r"\bAND\b")
res = torch.zeros(c.shape, device=shared.device, dtype=next(shared.sd_model.parameters()).dtype) re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
for i, cond_schedule in enumerate(c.schedules):
def get_multicond_prompt_list(prompts):
res_indexes = []
prompt_flat_list = []
prompt_indexes = {}
for prompt in prompts:
subprompts = re_AND.split(prompt)
indexes = []
for subprompt in subprompts:
match = re_weight.search(subprompt)
text, weight = match.groups() if match is not None else (subprompt, 1.0)
weight = float(weight) if weight is not None else 1.0
index = prompt_indexes.get(text, None)
if index is None:
index = len(prompt_flat_list)
prompt_flat_list.append(text)
prompt_indexes[text] = index
indexes.append((index, weight))
res_indexes.append(indexes)
return res_indexes, prompt_flat_list, prompt_indexes
class ComposableScheduledPromptConditioning:
def __init__(self, schedules, weight=1.0):
self.schedules: List[ScheduledPromptConditioning] = schedules
self.weight: float = weight
class MulticondLearnedConditioning:
def __init__(self, shape, batch):
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
For each prompt, the list is obtained by splitting the prompt using the AND separator.
https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
"""
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
res = []
for indexes in res_indexes:
res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
param = c[0][0].cond
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
for i, cond_schedule in enumerate(c):
target_index = 0 target_index = 0
for curret_index, (end_at, cond) in enumerate(cond_schedule): for current, (end_at, cond) in enumerate(cond_schedule):
if current_step <= end_at: if current_step <= end_at:
target_index = curret_index target_index = current
break break
res[i] = cond_schedule[target_index].cond res[i] = cond_schedule[target_index].cond
return res return res
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
param = c.batch[0][0].schedules[0].cond
tensors = []
conds_list = []
for batch_no, composable_prompts in enumerate(c.batch):
conds_for_batch = []
for cond_index, composable_prompt in enumerate(composable_prompts):
target_index = 0
for current, (end_at, cond) in enumerate(composable_prompt.schedules):
if current_step <= end_at:
target_index = current
break
conds_for_batch.append((len(tensors), composable_prompt.weight))
tensors.append(composable_prompt.schedules[target_index].cond)
conds_list.append(conds_for_batch)
return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
re_attention = re.compile(r""" re_attention = re.compile(r"""
\\\(| \\\(|
\\\)| \\\)|
@ -157,23 +273,26 @@ def parse_prompt_attention(text):
\\ - literal character '\' \\ - literal character '\'
anything else - just text anything else - just text
Example: >>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).' >>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
produces: >>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
[ >>> parse_prompt_attention('\(literal\]')
['a ', 1.0], [['(literal]', 1.0]]
['house', 1.5730000000000004], >>> parse_prompt_attention('(unnecessary)(parens)')
[' ', 1.1], [['unnecessaryparens', 1.1]]
['on', 1.0], >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[' a ', 1.1], [['a ', 1.0],
['hill', 0.55], ['house', 1.5730000000000004],
[', sun, ', 1.1], [' ', 1.1],
['sky', 1.4641000000000006], ['on', 1.0],
['.', 1.1] [' a ', 1.1],
] ['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]]
""" """
res = [] res = []
@ -215,4 +334,19 @@ def parse_prompt_attention(text):
if len(res) == 0: if len(res) == 0:
res = [["", 1.0]] res = [["", 1.0]]
# merge runs of identical weights
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res return res
if __name__ == "__main__":
import doctest
doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
else:
import torch # doctest faster

View file

@ -162,6 +162,40 @@ class ScriptRunner:
return processed return processed
def reload_sources(self):
for si, script in list(enumerate(self.scripts)):
with open(script.filename, "r", encoding="utf8") as file:
args_from = script.args_from
args_to = script.args_to
filename = script.filename
text = file.read()
from types import ModuleType
compiled = compile(text, filename, 'exec')
module = ModuleType(script.filename)
exec(compiled, module.__dict__)
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
self.scripts[si] = script_class()
self.scripts[si].filename = filename
self.scripts[si].args_from = args_from
self.scripts[si].args_to = args_to
scripts_txt2img = ScriptRunner() scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner() scripts_img2img = ScriptRunner()
def reload_script_body_only():
scripts_txt2img.reload_sources()
scripts_img2img.reload_sources()
def reload_scripts(basedir):
global scripts_txt2img, scripts_img2img
scripts_data.clear()
load_scripts(basedir)
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()

90
modules/scunet_model.py Normal file
View file

@ -0,0 +1,90 @@
import os.path
import sys
import traceback
import PIL.Image
import numpy as np
import torch
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
from modules import devices, modelloader
from modules.paths import models_path
from modules.scunet_model_arch import SCUNet as net
class UpscalerScuNET(modules.upscaler.Upscaler):
def __init__(self, dirname):
self.name = "ScuNET"
self.model_path = os.path.join(models_path, self.name)
self.model_name = "ScuNET GAN"
self.model_name2 = "ScuNET PSNR"
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth"
self.user_path = dirname
super().__init__()
model_paths = self.find_models(ext_filter=[".pth"])
scalers = []
add_model2 = True
for file in model_paths:
if "http" in file:
name = self.model_name
else:
name = modelloader.friendly_name(file)
if name == self.model_name2 or file == self.model_url2:
add_model2 = False
try:
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
scalers.append(scaler_data)
except Exception:
print(f"Error loading ScuNET model: {file}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
if add_model2:
scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
scalers.append(scaler_data2)
self.scalers = scalers
def do_upscale(self, img: PIL.Image, selected_file):
torch.cuda.empty_cache()
model = self.load_model(selected_file)
if model is None:
return img
device = devices.device_scunet
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(device)
img = img.to(device)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
output = 255. * np.moveaxis(output, 0, 2)
output = output.astype(np.uint8)
output = output[:, :, ::-1]
torch.cuda.empty_cache()
return PIL.Image.fromarray(output, 'RGB')
def load_model(self, path: str):
device = devices.device_scunet
if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
progress=True)
else:
filename = path
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
return None
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model.load_state_dict(torch.load(filename), strict=True)
model.eval()
for k, v in model.named_parameters():
v.requires_grad = False
model = model.to(device)
return model

View file

@ -0,0 +1,265 @@
# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import trunc_normal_, DropPath
class WMSA(nn.Module):
""" Self-attention module in Swin Transformer
"""
def __init__(self, input_dim, output_dim, head_dim, window_size, type):
super(WMSA, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.head_dim = head_dim
self.scale = self.head_dim ** -0.5
self.n_heads = input_dim // head_dim
self.window_size = window_size
self.type = type
self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
self.relative_position_params = nn.Parameter(
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))
self.linear = nn.Linear(self.input_dim, self.output_dim)
trunc_normal_(self.relative_position_params, std=.02)
self.relative_position_params = torch.nn.Parameter(
self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1,
2).transpose(
0, 1))
def generate_mask(self, h, w, p, shift):
""" generating the mask of SW-MSA
Args:
shift: shift parameters in CyclicShift.
Returns:
attn_mask: should be (1 1 w p p),
"""
# supporting sqaure.
attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
if self.type == 'W':
return attn_mask
s = p - shift
attn_mask[-1, :, :s, :, s:, :] = True
attn_mask[-1, :, s:, :, :s, :] = True
attn_mask[:, -1, :, :s, :, s:] = True
attn_mask[:, -1, :, s:, :, :s] = True
attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
return attn_mask
def forward(self, x):
""" Forward pass of Window Multi-head Self-attention module.
Args:
x: input tensor with shape of [b h w c];
attn_mask: attention mask, fill -inf where the value is True;
Returns:
output: tensor shape [b h w c]
"""
if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
h_windows = x.size(1)
w_windows = x.size(2)
# sqaure validation
# assert h_windows == w_windows
x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
qkv = self.embedding_layer(x)
q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
# Adding learnable relative embedding
sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
# Using Attn Mask to distinguish different subwindows.
if self.type != 'W':
attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2)
sim = sim.masked_fill_(attn_mask, float("-inf"))
probs = nn.functional.softmax(sim, dim=-1)
output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
output = rearrange(output, 'h b w p c -> b w p (h c)')
output = self.linear(output)
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
dims=(1, 2))
return output
def relative_embedding(self):
cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
# negative is allowed
return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]
class Block(nn.Module):
def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
""" SwinTransformer Block
"""
super(Block, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
assert type in ['W', 'SW']
self.type = type
if input_resolution <= window_size:
self.type = 'W'
self.ln1 = nn.LayerNorm(input_dim)
self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.ln2 = nn.LayerNorm(input_dim)
self.mlp = nn.Sequential(
nn.Linear(input_dim, 4 * input_dim),
nn.GELU(),
nn.Linear(4 * input_dim, output_dim),
)
def forward(self, x):
x = x + self.drop_path(self.msa(self.ln1(x)))
x = x + self.drop_path(self.mlp(self.ln2(x)))
return x
class ConvTransBlock(nn.Module):
def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
""" SwinTransformer and Conv Block
"""
super(ConvTransBlock, self).__init__()
self.conv_dim = conv_dim
self.trans_dim = trans_dim
self.head_dim = head_dim
self.window_size = window_size
self.drop_path = drop_path
self.type = type
self.input_resolution = input_resolution
assert self.type in ['W', 'SW']
if self.input_resolution <= self.window_size:
self.type = 'W'
self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path,
self.type, self.input_resolution)
self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
self.conv_block = nn.Sequential(
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
nn.ReLU(True),
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
)
def forward(self, x):
conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
conv_x = self.conv_block(conv_x) + conv_x
trans_x = Rearrange('b c h w -> b h w c')(trans_x)
trans_x = self.trans_block(trans_x)
trans_x = Rearrange('b h w c -> b c h w')(trans_x)
res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
x = x + res
return x
class SCUNet(nn.Module):
# def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256):
def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
super(SCUNet, self).__init__()
if config is None:
config = [2, 2, 2, 2, 2, 2, 2]
self.config = config
self.dim = dim
self.head_dim = 32
self.window_size = 8
# drop path rate for each layer
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
begin = 0
self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution)
for i in range(config[0])] + \
[nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
begin += config[0]
self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution // 2)
for i in range(config[1])] + \
[nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
begin += config[1]
self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution // 4)
for i in range(config[2])] + \
[nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
begin += config[2]
self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution // 8)
for i in range(config[3])]
begin += config[3]
self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \
[ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution // 4)
for i in range(config[4])]
begin += config[4]
self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \
[ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution // 2)
for i in range(config[5])]
begin += config[5]
self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \
[ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution)
for i in range(config[6])]
self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
self.m_head = nn.Sequential(*self.m_head)
self.m_down1 = nn.Sequential(*self.m_down1)
self.m_down2 = nn.Sequential(*self.m_down2)
self.m_down3 = nn.Sequential(*self.m_down3)
self.m_body = nn.Sequential(*self.m_body)
self.m_up3 = nn.Sequential(*self.m_up3)
self.m_up2 = nn.Sequential(*self.m_up2)
self.m_up1 = nn.Sequential(*self.m_up1)
self.m_tail = nn.Sequential(*self.m_tail)
# self.apply(self._init_weights)
def forward(self, x0):
h, w = x0.size()[-2:]
paddingBottom = int(np.ceil(h / 64) * 64 - h)
paddingRight = int(np.ceil(w / 64) * 64 - w)
x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
x1 = self.m_head(x0)
x2 = self.m_down1(x1)
x3 = self.m_down2(x2)
x4 = self.m_down3(x3)
x = self.m_body(x4)
x = self.m_up3(x + x4)
x = self.m_up2(x + x3)
x = self.m_up1(x + x2)
x = self.m_tail(x + x1)
x = x[..., :h, :w]
return x
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)

View file

@ -5,240 +5,44 @@ import traceback
import torch import torch
import numpy as np import numpy as np
from torch import einsum from torch import einsum
from torch.nn.functional import silu
from modules import prompt_parser import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
from modules.shared import opts, device, cmd_opts from modules.shared import opts, device, cmd_opts
from ldm.util import default
from einops import rearrange
import ldm.modules.attention import ldm.modules.attention
import ldm.modules.diffusionmodules.model import ldm.modules.diffusionmodules.model
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
def split_cross_attention_forward_v1(self, x, context=None, mask=None): diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
for i in range(0, q.shape[0], 2):
end = i + 2
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
s2 = s1.softmax(dim=-1)
del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
# taken from https://github.com/Doggettx/stable-diffusion def apply_optimizations():
def split_cross_attention_forward(self, x, context=None, mask=None): ldm.modules.diffusionmodules.model.nonlinearity = silu
h = self.heads
q_in = self.to_q(x) if cmd_opts.opt_split_attention_v1:
context = default(context, x) ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
k_in = self.to_k(context) * self.scale elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
v_in = self.to_v(context) ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
del context, x ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) def undo_optimizations():
ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
def nonlinearity_hijack(x):
# swish
t = torch.sigmoid(x)
x *= t
del t
return x
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
q1 = self.q(h_)
k1 = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q1.shape
q2 = q1.reshape(b, c, h*w)
del q1
q = q2.permute(0, 2, 1) # b,hw,c
del q2
k = k1.reshape(b, c, h*w) # b,c,hw
del k1
h_ = torch.zeros_like(k, device=q.device)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w2 = w1 * (int(c)**(-0.5))
del w1
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
del w2
# attend to values
v1 = v.reshape(b, c, h*w)
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
del w3
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
del v1, w4
h2 = h_.reshape(b, c, h, w)
del h_
h3 = self.proj_out(h2)
del h2
h3 += x
return h3
class StableDiffusionModelHijack: class StableDiffusionModelHijack:
ids_lookup = {}
word_embeddings = {}
word_embeddings_checksums = {}
fixes = None fixes = None
comments = [] comments = []
dir_mtime = None
layers = None layers = None
circular_enabled = False circular_enabled = False
clip = None clip = None
def load_textual_inversion_embeddings(self, dirname, model): embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
mt = os.path.getmtime(dirname)
if self.dir_mtime is not None and mt <= self.dir_mtime:
return
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
tokenizer = model.cond_stage_model.tokenizer
def const_hash(a):
r = 0
for v in a:
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
return r
def process_file(path, filename):
name = os.path.splitext(filename)[0]
data = torch.load(path, map_location="cpu")
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
if hasattr(param_dict, '_parameters'):
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
self.word_embeddings[name] = emb.detach().to(device)
self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}'
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
self.ids_lookup[first_id].append((ids, name))
for fn in os.listdir(dirname):
try:
process_file(os.path.join(dirname, fn), fn)
except Exception:
print(f"Error loading emedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
def hijack(self, m): def hijack(self, m):
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
@ -248,12 +52,7 @@ class StableDiffusionModelHijack:
self.clip = m.cond_stage_model self.clip = m.cond_stage_model
if cmd_opts.opt_split_attention_v1: apply_optimizations()
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
def flatten(el): def flatten(el):
flattened = [flatten(children) for children in el.children()] flattened = [flatten(children) for children in el.children()]
@ -291,7 +90,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack): def __init__(self, wrapped, hijack):
super().__init__() super().__init__()
self.wrapped = wrapped self.wrapped = wrapped
self.hijack = hijack self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer self.tokenizer = wrapped.tokenizer
self.max_length = wrapped.max_length self.max_length = wrapped.max_length
self.token_mults = {} self.token_mults = {}
@ -312,7 +111,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0: if mult != 1.0:
self.token_mults[ident] = mult self.token_mults[ident] = mult
def tokenize_line(self, line, used_custom_terms, hijack_comments): def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_start = self.wrapped.tokenizer.bos_token_id id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id id_end = self.wrapped.tokenizer.eos_token_id
@ -334,28 +132,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens): while i < len(tokens):
token = tokens[i] token = tokens[i]
possible_matches = self.hijack.ids_lookup.get(token, None) embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
if possible_matches is None: if embedding is None:
remade_tokens.append(token) remade_tokens.append(token)
multipliers.append(weight) multipliers.append(weight)
i += 1
else: else:
found = False emb_len = int(embedding.vec.shape[0])
for ids, word in possible_matches: fixes.append((len(remade_tokens), embedding))
if tokens[i:i + len(ids)] == ids: remade_tokens += [0] * emb_len
emb_len = int(self.hijack.word_embeddings[word].shape[0]) multipliers += [weight] * emb_len
fixes.append((len(remade_tokens), word)) used_custom_terms.append((embedding.name, embedding.checksum()))
remade_tokens += [0] * emb_len i += embedding_length_in_tokens
multipliers += [weight] * emb_len
i += len(ids) - 1
found = True
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
break
if not found:
remade_tokens.append(token)
multipliers.append(weight)
i += 1
if len(remade_tokens) > maxlen - 2: if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
@ -426,32 +215,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens): while i < len(tokens):
token = tokens[i] token = tokens[i]
possible_matches = self.hijack.ids_lookup.get(token, None) embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None: if mult_change is not None:
mult *= mult_change mult *= mult_change
elif possible_matches is None: i += 1
elif embedding is None:
remade_tokens.append(token) remade_tokens.append(token)
multipliers.append(mult) multipliers.append(mult)
i += 1
else: else:
found = False emb_len = int(embedding.vec.shape[0])
for ids, word in possible_matches: fixes.append((len(remade_tokens), embedding))
if tokens[i:i+len(ids)] == ids: remade_tokens += [0] * emb_len
emb_len = int(self.hijack.word_embeddings[word].shape[0]) multipliers += [mult] * emb_len
fixes.append((len(remade_tokens), word)) used_custom_terms.append((embedding.name, embedding.checksum()))
remade_tokens += [0] * emb_len i += embedding_length_in_tokens
multipliers += [mult] * emb_len
i += len(ids) - 1
found = True
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
break
if not found:
remade_tokens.append(token)
multipliers.append(mult)
i += 1
if len(remade_tokens) > maxlen - 2: if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
@ -459,6 +239,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
overflowing_words = [vocab.get(int(x), "") for x in ovf] overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
token_count = len(remade_tokens) token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
@ -479,7 +260,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
else: else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
self.hijack.fixes = hijack_fixes self.hijack.fixes = hijack_fixes
self.hijack.comments = hijack_comments self.hijack.comments = hijack_comments
@ -512,14 +292,19 @@ class EmbeddingsWithFixes(torch.nn.Module):
inputs_embeds = self.wrapped(input_ids) inputs_embeds = self.wrapped(input_ids)
if batch_fixes is not None: if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
for fixes, tensor in zip(batch_fixes, inputs_embeds): return inputs_embeds
for offset, word in fixes:
emb = self.embeddings.word_embeddings[word]
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
tensor[offset+1:offset+1+emb_len] = self.embeddings.word_embeddings[word][0:emb_len]
return inputs_embeds vecs = []
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes:
emb = embedding.vec
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
vecs.append(tensor)
return torch.stack(vecs)
def add_circular_option_to_conv_2d(): def add_circular_option_to_conv_2d():

View file

@ -0,0 +1,156 @@
import math
import torch
from torch import einsum
from ldm.util import default
from einops import rearrange
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
for i in range(0, q.shape[0], 2):
end = i + 2
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
s2 = s1.softmax(dim=-1)
del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
# taken from https://github.com/Doggettx/stable-diffusion
def split_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context) * self.scale
v_in = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
q1 = self.q(h_)
k1 = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q1.shape
q2 = q1.reshape(b, c, h*w)
del q1
q = q2.permute(0, 2, 1) # b,hw,c
del q2
k = k1.reshape(b, c, h*w) # b,c,hw
del k1
h_ = torch.zeros_like(k, device=q.device)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w2 = w1 * (int(c)**(-0.5))
del w1
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
del w2
# attend to values
v1 = v.reshape(b, c, h*w)
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
del w3
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
del v1, w4
h2 = h_.reshape(b, c, h, w)
del h_
h3 = self.proj_out(h2)
del h2
h3 += x
return h3

View file

@ -8,14 +8,11 @@ from omegaconf import OmegaConf
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from modules import shared, modelloader from modules import shared, modelloader, devices
from modules.paths import models_path from modules.paths import models_path
model_dir = "Stable-diffusion" model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir)) model_path = os.path.abspath(os.path.join(models_path, model_dir))
model_name = "sd-v1-4.ckpt"
model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1"
user_dir = None
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name']) CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {} checkpoints_list = {}
@ -30,12 +27,10 @@ except Exception:
pass pass
def setup_model(dirname): def setup_model():
global user_dir
user_dir = dirname
if not os.path.exists(model_path): if not os.path.exists(model_path):
os.makedirs(model_path) os.makedirs(model_path)
checkpoints_list.clear()
list_models() list_models()
@ -45,13 +40,13 @@ def checkpoint_tiles():
def list_models(): def list_models():
checkpoints_list.clear() checkpoints_list.clear()
model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=user_dir, ext_filter=[".ckpt"], download_name=model_name) model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])
def modeltitle(path, shorthash): def modeltitle(path, shorthash):
abspath = os.path.abspath(path) abspath = os.path.abspath(path)
if user_dir is not None and abspath.startswith(user_dir): if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
name = abspath.replace(user_dir, '') name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
elif abspath.startswith(model_path): elif abspath.startswith(model_path):
name = abspath.replace(model_path, '') name = abspath.replace(model_path, '')
else: else:
@ -69,6 +64,7 @@ def list_models():
h = model_hash(cmd_ckpt) h = model_hash(cmd_ckpt)
title, short_model_name = modeltitle(cmd_ckpt, h) title, short_model_name = modeltitle(cmd_ckpt, h)
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name) checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
shared.opts.data['sd_model_checkpoint'] = title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
for filename in model_list: for filename in model_list:
@ -105,8 +101,11 @@ def select_checkpoint():
if len(checkpoints_list) == 0: if len(checkpoints_list) == 0:
print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr) print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr) if shared.cmd_opts.ckpt is not None:
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr) print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
print(f" - directory {model_path}", file=sys.stderr)
if shared.cmd_opts.ckpt_dir is not None:
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr) print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
exit(1) exit(1)
@ -133,6 +132,8 @@ def load_model_weights(model, checkpoint_file, sd_model_hash):
if not shared.cmd_opts.no_half: if not shared.cmd_opts.no_half:
model.half() model.half()
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
model.sd_model_hash = sd_model_hash model.sd_model_hash = sd_model_hash
model.sd_model_checkpint = checkpoint_file model.sd_model_checkpint = checkpoint_file

View file

@ -13,31 +13,57 @@ from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases']) SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
samplers_k_diffusion = [ samplers_k_diffusion = [
('Euler a', 'sample_euler_ancestral', ['k_euler_a']), ('Euler a', 'sample_euler_ancestral', ['k_euler_a'], {}),
('Euler', 'sample_euler', ['k_euler']), ('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms']), ('LMS', 'sample_lms', ['k_lms'], {}),
('Heun', 'sample_heun', ['k_heun']), ('Heun', 'sample_heun', ['k_heun'], {}),
('DPM2', 'sample_dpm_2', ['k_dpm_2']), ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a']), ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast']), ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad']), ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
] ]
samplers_data_k_diffusion = [ samplers_data_k_diffusion = [
SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases) SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
for label, funcname, aliases in samplers_k_diffusion for label, funcname, aliases, options in samplers_k_diffusion
if hasattr(k_diffusion.sampling, funcname) if hasattr(k_diffusion.sampling, funcname)
] ]
samplers = [ all_samplers = [
*samplers_data_k_diffusion, *samplers_data_k_diffusion,
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []), SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []), SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
] ]
samplers_for_img2img = [x for x in samplers if x.name not in ['PLMS', 'DPM fast', 'DPM adaptive']]
samplers = []
samplers_for_img2img = []
def create_sampler_with_index(list_of_configs, index, model):
config = list_of_configs[index]
sampler = config.constructor(model)
sampler.config = config
return sampler
def set_samplers():
global samplers, samplers_for_img2img
hidden = set(opts.hide_samplers)
hidden_img2img = set(opts.hide_samplers + ['PLMS', 'DPM fast', 'DPM adaptive'])
samplers = [x for x in all_samplers if x.name not in hidden]
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
set_samplers()
sampler_extra_params = { sampler_extra_params = {
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
@ -77,7 +103,9 @@ def extended_tdqm(sequence, *args, desc=None, **kwargs):
state.sampling_steps = len(sequence) state.sampling_steps = len(sequence)
state.sampling_step = 0 state.sampling_step = 0
for x in tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs): seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
for x in seq:
if state.interrupted: if state.interrupted:
break break
@ -102,14 +130,18 @@ class VanillaStableDiffusionSampler:
self.step = 0 self.step = 0
self.eta = None self.eta = None
self.default_eta = 0.0 self.default_eta = 0.0
self.config = None
def number_of_needed_noises(self, p): def number_of_needed_noises(self, p):
return 0 return 0
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
cond = prompt_parser.reconstruct_cond_batch(cond, self.step) conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
cond = tensor
if self.mask is not None: if self.mask is not None:
img_orig = self.sampler.model.q_sample(self.init_latent, ts) img_orig = self.sampler.model.q_sample(self.init_latent, ts)
x_dec = img_orig * self.mask + self.nmask * x_dec x_dec = img_orig * self.mask + self.nmask * x_dec
@ -125,7 +157,7 @@ class VanillaStableDiffusionSampler:
return res return res
def initialize(self, p): def initialize(self, p):
self.eta = p.eta or opts.eta_ddim self.eta = p.eta if p.eta is not None else opts.eta_ddim
for fieldname in ['p_sample_ddim', 'p_sample_plms']: for fieldname in ['p_sample_ddim', 'p_sample_plms']:
if hasattr(self.sampler, fieldname): if hasattr(self.sampler, fieldname):
@ -181,19 +213,31 @@ class CFGDenoiser(torch.nn.Module):
self.step = 0 self.step = 0
def forward(self, x, sigma, uncond, cond, cond_scale): def forward(self, x, sigma, uncond, cond, cond_scale):
cond = prompt_parser.reconstruct_cond_batch(cond, self.step) conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
cond_in = torch.cat([tensor, uncond])
if shared.batch_cond_uncond: if shared.batch_cond_uncond:
x_in = torch.cat([x] * 2) x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
denoised = uncond + (cond - uncond) * cond_scale
else: else:
uncond = self.inner_model(x, sigma, cond=uncond) x_out = torch.zeros_like(x_in)
cond = self.inner_model(x, sigma, cond=cond) for batch_offset in range(0, x_out.shape[0], batch_size):
denoised = uncond + (cond - uncond) * cond_scale a = batch_offset
b = a + batch_size
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
denoised_uncond = x_out[-batch_size:]
denoised = torch.clone(denoised_uncond)
for i, conds in enumerate(conds_list):
for cond_index, weight in conds:
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
if self.mask is not None: if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised denoised = self.init_latent * self.mask + self.nmask * denoised
@ -207,7 +251,9 @@ def extended_trange(sampler, count, *args, **kwargs):
state.sampling_steps = count state.sampling_steps = count
state.sampling_step = 0 state.sampling_step = 0
for x in tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs): seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
for x in seq:
if state.interrupted: if state.interrupted:
break break
@ -246,6 +292,7 @@ class KDiffusionSampler:
self.stop_at = None self.stop_at = None
self.eta = None self.eta = None
self.default_eta = 1.0 self.default_eta = 1.0
self.config = None
def callback_state(self, d): def callback_state(self, d):
store_latent(d["denoised"]) store_latent(d["denoised"])
@ -290,7 +337,10 @@ class KDiffusionSampler:
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None): def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
steps, t_enc = setup_img2img_steps(p, steps) steps, t_enc = setup_img2img_steps(p, steps)
sigmas = self.model_wrap.get_sigmas(steps) if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
else:
sigmas = self.model_wrap.get_sigmas(steps)
noise = noise * sigmas[steps - t_enc - 1] noise = noise * sigmas[steps - t_enc - 1]
xi = x + noise xi = x + noise
@ -306,7 +356,13 @@ class KDiffusionSampler:
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None): def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
steps = steps or p.steps steps = steps or p.steps
sigmas = self.model_wrap.get_sigmas(steps) if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
else:
sigmas = self.model_wrap.get_sigmas(steps)
x = x * sigmas[0] x = x * sigmas[0]
extra_params_kwargs = self.initialize(p) extra_params_kwargs = self.initialize(p)

View file

@ -12,15 +12,15 @@ import modules.interrogate
import modules.memmon import modules.memmon
import modules.sd_models import modules.sd_models
import modules.styles import modules.styles
from modules.devices import get_optimal_device import modules.devices as devices
from modules.paths import script_path, sd_path from modules import sd_samplers
from modules.paths import models_path, script_path, sd_path
sd_model_file = os.path.join(script_path, 'model.ckpt') sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file default_sd_model_file = sd_model_file
model_path = os.path.join(script_path, 'models')
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None) parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
@ -35,16 +35,18 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="dis
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(model_path, 'Codeformer')) parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(model_path, 'GFPGAN')) parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN')) parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN')) parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN')) parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR')) parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR')) parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[])
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
@ -53,13 +55,21 @@ parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide dire
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json')) parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor")
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
cmd_opts = parser.parse_args() cmd_opts = parser.parse_args()
device = get_optimal_device()
devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'])
device = devices.device
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
@ -78,6 +88,7 @@ class State:
current_latent = None current_latent = None
current_image = None current_image = None
current_image_sampling_step = 0 current_image_sampling_step = 0
textinfo = None
def interrupt(self): def interrupt(self):
self.interrupted = True self.interrupted = True
@ -88,7 +99,7 @@ class State:
self.current_image_sampling_step = 0 self.current_image_sampling_step = 0
def get_job_timestamp(self): def get_job_timestamp(self):
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
state = State() state = State()
@ -165,9 +176,10 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), { options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
"save_to_dirs": OptionInfo(False, "Save images to a subdirectory"), "save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
"grid_save_to_dirs": OptionInfo(False, "Save grids to subdirectory"), "grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"),
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
"directories_filename_pattern": OptionInfo("", "Directory name pattern"), "directories_filename_pattern": OptionInfo("", "Directory name pattern"),
"directories_max_prompt_words": OptionInfo(8, "Max prompt words", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}), "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}),
})) }))
options_templates.update(options_section(('upscaling', "Upscaling"), { options_templates.update(options_section(('upscaling', "Upscaling"), {
@ -177,7 +189,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}), "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}), "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
})) }))
options_templates.update(options_section(('face-restoration', "Face restoration"), { options_templates.update(options_section(('face-restoration', "Face restoration"), {
@ -189,7 +201,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('system', "System"), {
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}), "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."), "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
})) }))
options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {
@ -198,7 +210,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"enable_emphasis": OptionInfo(True, "Eemphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"filter_nsfw": OptionInfo(False, "Filter NSFW content"), "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
@ -218,21 +230,25 @@ options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"), "show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), "show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
"return_grid": OptionInfo(True, "Show grid in results for web"), "return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"font": OptionInfo("", "Font for image grids that have text"), "font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
})) }))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), { options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}),
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
})) }))
class Options: class Options:
data = None data = None
data_labels = options_templates data_labels = options_templates
@ -318,14 +334,14 @@ class TotalTQDM:
) )
def update(self): def update(self):
if not opts.multiple_tqdm: if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
return return
if self._tqdm is None: if self._tqdm is None:
self.reset() self.reset()
self._tqdm.update() self._tqdm.update()
def updateTotal(self, new_total): def updateTotal(self, new_total):
if not opts.multiple_tqdm: if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
return return
if self._tqdm is None: if self._tqdm is None:
self.reset() self.reset()

View file

@ -5,6 +5,7 @@ import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
from modules import modelloader from modules import modelloader
from modules.paths import models_path from modules.paths import models_path
@ -122,18 +123,20 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img) E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
W = torch.zeros_like(E, dtype=torch.half, device=device) W = torch.zeros_like(E, dtype=torch.half, device=device)
for h_idx in h_idx_list: with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for w_idx in w_idx_list: for h_idx in h_idx_list:
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] for w_idx in w_idx_list:
out_patch = model(in_patch) in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch_mask = torch.ones_like(out_patch) out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
E[ E[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch) ].add_(out_patch)
W[ W[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch_mask) ].add_(out_patch_mask)
pbar.update(1)
output = E.div_(W) output = E.div_(W)
return output return output

View file

@ -0,0 +1,81 @@
import os
import numpy as np
import PIL
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import random
import tqdm
from modules import devices
import re
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
class PersonalizedBase(Dataset):
def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
self.placeholder_token = placeholder_token
self.size = size
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.dataset = []
with open(template_file, "r") as file:
lines = [x.strip() for x in file.readlines()]
self.lines = lines
assert data_root, 'dataset directory not specified'
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
image = Image.open(path)
image = image.convert('RGB')
image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
filename = os.path.basename(path)
filename_tokens = os.path.splitext(filename)[0]
filename_tokens = re_tag.findall(filename_tokens)
npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
torchdata = torch.moveaxis(torchdata, 2, 0)
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
init_latent = init_latent.to(devices.cpu)
self.dataset.append((init_latent, filename_tokens))
self.length = len(self.dataset) * repeats
self.initial_indexes = np.arange(self.length) % len(self.dataset)
self.indexes = None
self.shuffle()
def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
def __len__(self):
return self.length
def __getitem__(self, i):
if i % len(self.dataset) == 0:
self.shuffle()
index = self.indexes[i % len(self.indexes)]
x, filename_tokens = self.dataset[index]
text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token)
text = text.replace("[filewords]", ' '.join(filename_tokens))
return x, text

View file

@ -0,0 +1,104 @@
import os
from PIL import Image, ImageOps
import platform
import sys
import tqdm
from modules import shared, images
def preprocess(process_src, process_dst, process_flip, process_split, process_caption):
size = 512
src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst)
assert src != dst, 'same directory specified as source and destination'
os.makedirs(dst, exist_ok=True)
files = os.listdir(src)
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
if process_caption:
shared.interrogator.load()
def save_pic_with_caption(image, index):
if process_caption:
caption = "-" + shared.interrogator.generate_caption(image)
caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
else:
caption = filename
caption = os.path.splitext(caption)[0]
caption = os.path.basename(caption)
image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
subindex[0] += 1
def save_pic(image, index):
save_pic_with_caption(image, index)
if process_flip:
save_pic_with_caption(ImageOps.mirror(image), index)
for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0]
filename = os.path.join(src, imagefile)
img = Image.open(filename).convert("RGB")
if shared.state.interrupted:
break
ratio = img.height / img.width
is_tall = ratio > 1.35
is_wide = ratio < 1 / 1.35
if process_split and is_tall:
img = img.resize((size, size * img.height // img.width))
top = img.crop((0, 0, size, size))
save_pic(top, index)
bot = img.crop((0, img.height - size, size, img.height))
save_pic(bot, index)
elif process_split and is_wide:
img = img.resize((size * img.width // img.height, size))
left = img.crop((0, 0, size, size))
save_pic(left, index)
right = img.crop((img.width - size, 0, img.width, size))
save_pic(right, index)
else:
img = images.resize_image(1, img, size, size)
save_pic(img, index)
shared.state.nextjob()
if process_caption:
shared.interrogator.send_blip_to_ram()
def sanitize_caption(base_path, original_caption, suffix):
operating_system = platform.system().lower()
if (operating_system == "windows"):
invalid_path_characters = "\\/:*?\"<>|"
max_path_length = 259
else:
invalid_path_characters = "/" #linux/macos
max_path_length = 1023
caption = original_caption
for invalid_character in invalid_path_characters:
caption = caption.replace(invalid_character, "")
fixed_path_length = len(base_path) + len(suffix)
if fixed_path_length + len(caption) <= max_path_length:
return caption
caption_tokens = caption.split()
new_caption = ""
for token in caption_tokens:
last_caption = new_caption
new_caption = new_caption + token + " "
if (len(new_caption) + fixed_path_length - 1 > max_path_length):
break
print(f"\nPath will be too long. Truncated caption: {original_caption}\nto: {last_caption}", file=sys.stderr)
return last_caption.strip()

View file

@ -0,0 +1,271 @@
import os
import sys
import traceback
import torch
import tqdm
import html
import datetime
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
class Embedding:
def __init__(self, vec, name, step=None):
self.vec = vec
self.name = name
self.step = step
self.cached_checksum = None
self.sd_checkpoint = None
self.sd_checkpoint_name = None
def save(self, filename):
embedding_data = {
"string_to_token": {"*": 265},
"string_to_param": {"*": self.vec},
"name": self.name,
"step": self.step,
"sd_checkpoint": self.sd_checkpoint,
"sd_checkpoint_name": self.sd_checkpoint_name,
}
torch.save(embedding_data, filename)
def checksum(self):
if self.cached_checksum is not None:
return self.cached_checksum
def const_hash(a):
r = 0
for v in a:
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
return r
self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
return self.cached_checksum
class EmbeddingDatabase:
def __init__(self, embeddings_dir):
self.ids_lookup = {}
self.word_embeddings = {}
self.dir_mtime = None
self.embeddings_dir = embeddings_dir
def register_embedding(self, embedding, model):
self.word_embeddings[embedding.name] = embedding
ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
return embedding
def load_textual_inversion_embeddings(self):
mt = os.path.getmtime(self.embeddings_dir)
if self.dir_mtime is not None and mt <= self.dir_mtime:
return
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
def process_file(path, filename):
name = os.path.splitext(filename)[0]
data = torch.load(path, map_location="cpu")
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
if hasattr(param_dict, '_parameters'):
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
else:
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
vec = emb.detach().to(devices.device, dtype=torch.float32)
embedding = Embedding(vec, name)
embedding.step = data.get('step', None)
embedding.sd_checkpoint = data.get('hash', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
self.register_embedding(embedding, shared.sd_model)
for fn in os.listdir(self.embeddings_dir):
try:
fullfn = os.path.join(self.embeddings_dir, fn)
if os.stat(fullfn).st_size == 0:
continue
process_file(fullfn, fn)
except Exception:
print(f"Error loading emedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
possible_matches = self.ids_lookup.get(token, None)
if possible_matches is None:
return None, None
for ids, embedding in possible_matches:
if tokens[offset:offset + len(ids)] == ids:
return embedding, len(ids)
return None, None
def create_embedding(name, num_vectors_per_token, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
for i in range(num_vectors_per_token):
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
embedding = Embedding(vec, name)
embedding.step = 0
embedding.save(fn)
return fn
def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings")
os.makedirs(embedding_dir, exist_ok=True)
else:
embedding_dir = None
if create_image_every > 0:
images_dir = os.path.join(log_directory, "images")
os.makedirs(images_dir, exist_ok=True)
else:
images_dir = None
cond_model = shared.sd_model.cond_stage_model
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name]
embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
losses = torch.zeros((32,))
last_saved_file = "<none>"
last_saved_image = "<none>"
ititial_step = embedding.step or 0
if ititial_step > steps:
return embedding, filename
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, (x, text) in pbar:
embedding.step = i + ititial_step
if embedding.step > steps:
break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
c = cond_model([text])
x = x.to(devices.device)
loss = shared.sd_model(x.unsqueeze(0), c)[0]
del x
losses[embedding.step % losses.shape[0]] = loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
pbar.set_description(f"loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
embedding.save(last_saved_file)
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
prompt=text,
steps=20,
do_not_save_grid=True,
do_not_save_samples=True,
)
processed = processing.process_images(p)
image = processed.images[0]
shared.state.current_image = image
image.save(last_saved_image)
last_saved_image += f", prompt: {text}"
shared.state.job_no = embedding.step
shared.state.textinfo = f"""
<p>
Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/>
Last prompt: {html.escape(text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
checkpoint = sd_models.select_checkpoint()
embedding.sd_checkpoint = checkpoint.hash
embedding.sd_checkpoint_name = checkpoint.model_name
embedding.cached_checksum = None
embedding.save(filename)
return embedding, filename

View file

@ -0,0 +1,40 @@
import html
import gradio as gr
import modules.textual_inversion.textual_inversion
import modules.textual_inversion.preprocess
from modules import sd_hijack, shared
def create_embedding(name, initialization_text, nvpt):
filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
def preprocess(*args):
modules.textual_inversion.preprocess.preprocess(*args)
return "Preprocessing finished.", ""
def train_embedding(*args):
try:
sd_hijack.undo_optimizations()
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
res = f"""
Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
Embedding saved to {html.escape(filename)}
"""
return res, ""
except Exception:
raise
finally:
sd_hijack.apply_optimizations()

View file

@ -34,7 +34,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
denoising_strength=denoising_strength if enable_hr else None, denoising_strength=denoising_strength if enable_hr else None,
) )
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
processed = modules.scripts.scripts_txt2img.run(p, *args) processed = modules.scripts.scripts_txt2img.run(p, *args)
if processed is None: if processed is None:
@ -46,5 +48,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
if opts.samples_log_stdout: if opts.samples_log_stdout:
print(generation_info_js) print(generation_info_js)
if opts.do_not_show_images:
processed.images = []
return processed.images, generation_info_js, plaintext_to_html(processed.info) return processed.images, generation_info_js, plaintext_to_html(processed.info)

View file

@ -11,6 +11,7 @@ import time
import traceback import traceback
import platform import platform
import subprocess as sp import subprocess as sp
from functools import reduce
import numpy as np import numpy as np
import torch import torch
@ -21,6 +22,7 @@ import gradio as gr
import gradio.utils import gradio.utils
import gradio.routes import gradio.routes
from modules import sd_hijack
from modules.paths import script_path from modules.paths import script_path
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts
import modules.shared as shared import modules.shared as shared
@ -32,6 +34,9 @@ import modules.gfpgan_model
import modules.codeformer_model import modules.codeformer_model
import modules.styles import modules.styles
import modules.generation_parameters_copypaste import modules.generation_parameters_copypaste
from modules import prompt_parser
from modules.images import save_image
import modules.textual_inversion.ui
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
mimetypes.init() mimetypes.init()
@ -64,7 +69,7 @@ random_symbol = '\U0001f3b2\ufe0f' # 🎲️
reuse_symbol = '\u267b\ufe0f' # ♻️ reuse_symbol = '\u267b\ufe0f' # ♻️
art_symbol = '\U0001f3a8' # 🎨 art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙ paste_symbol = '\u2199\ufe0f' # ↙
folder_symbol = '\uD83D\uDCC2' folder_symbol = '\U0001f4c2' # 📂
def plaintext_to_html(text): def plaintext_to_html(text):
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>" text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
@ -94,18 +99,28 @@ def send_gradio_gallery_to_image(x):
def save_files(js_data, images, index): def save_files(js_data, images, index):
import csv import csv
os.makedirs(opts.outdir_save, exist_ok=True)
filenames = [] filenames = []
#quick dictionary to class object conversion. Its neccesary due apply_filename_pattern requiring it
class MyObject:
def __init__(self, d=None):
if d is not None:
for key, value in d.items():
setattr(self, key, value)
data = json.loads(js_data) data = json.loads(js_data)
p = MyObject(data)
path = opts.outdir_save
save_to_dirs = opts.use_save_to_dirs_for_ui
extension: str = opts.samples_format
start_index = 0
if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
images = [images[index]] images = [images[index]]
infotexts = [data["infotexts"][index]] start_index = index
else:
infotexts = data["infotexts"]
with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
at_start = file.tell() == 0 at_start = file.tell() == 0
@ -113,28 +128,18 @@ def save_files(js_data, images, index):
if at_start: if at_start:
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
filename_base = str(int(time.time() * 1000)) for image_index, filedata in enumerate(images, start_index):
extension = opts.samples_format.lower()
for i, filedata in enumerate(images):
filename = filename_base + ("" if len(images) == 1 else "-" + str(i + 1)) + f".{extension}"
filepath = os.path.join(opts.outdir_save, filename)
if filedata.startswith("data:image/png;base64,"): if filedata.startswith("data:image/png;base64,"):
filedata = filedata[len("data:image/png;base64,"):] filedata = filedata[len("data:image/png;base64,"):]
image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8')))) image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8'))))
if opts.enable_pnginfo and extension == 'png':
pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text('parameters', infotexts[i])
image.save(filepath, pnginfo=pnginfo)
else:
image.save(filepath, quality=opts.jpeg_quality)
if opts.enable_pnginfo and extension in ("jpg", "jpeg", "webp"): is_grid = image_index < p.index_of_first_image
piexif.insert(piexif.dump({"Exif": { i = 0 if is_grid else (image_index - p.index_of_first_image)
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(infotexts[i], encoding="unicode")
}}), filepath)
fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
filename = os.path.relpath(fullfn, path)
filenames.append(filename) filenames.append(filename)
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
@ -142,8 +147,8 @@ def save_files(js_data, images, index):
return '', '', plaintext_to_html(f"Saved: {filenames[0]}") return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
def wrap_gradio_call(func): def wrap_gradio_call(func, extra_outputs=None):
def f(*args, **kwargs): def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
if run_memmon: if run_memmon:
shared.mem_mon.monitor() shared.mem_mon.monitor()
@ -159,9 +164,17 @@ def wrap_gradio_call(func):
shared.state.job = "" shared.state.job = ""
shared.state.job_count = 0 shared.state.job_count = 0
res = [None, '', f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"] if extra_outputs_array is None:
extra_outputs_array = [None, '']
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
elapsed = time.perf_counter() - t elapsed = time.perf_counter() - t
elapsed_m = int(elapsed // 60)
elapsed_s = elapsed % 60
elapsed_text = f"{elapsed_s:.2f}s"
if (elapsed_m > 0):
elapsed_text = f"{elapsed_m}m "+elapsed_text
if run_memmon: if run_memmon:
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()} mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
@ -176,9 +189,10 @@ def wrap_gradio_call(func):
vram_html = '' vram_html = ''
# last item is always HTML # last item is always HTML
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>" res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
shared.state.interrupted = False shared.state.interrupted = False
shared.state.job_count = 0
return tuple(res) return tuple(res)
@ -187,7 +201,7 @@ def wrap_gradio_call(func):
def check_progress_call(id_part): def check_progress_call(id_part):
if shared.state.job_count == 0: if shared.state.job_count == 0:
return "", gr_show(False), gr_show(False) return "", gr_show(False), gr_show(False), gr_show(False)
progress = 0 progress = 0
@ -219,13 +233,19 @@ def check_progress_call(id_part):
else: else:
preview_visibility = gr_show(True) preview_visibility = gr_show(True)
return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image if shared.state.textinfo is not None:
textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
else:
textinfo_result = gr_show(False)
return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
def check_progress_call_initial(id_part): def check_progress_call_initial(id_part):
shared.state.job_count = -1 shared.state.job_count = -1
shared.state.current_latent = None shared.state.current_latent = None
shared.state.current_image = None shared.state.current_image = None
shared.state.textinfo = None
return check_progress_call(id_part) return check_progress_call(id_part)
@ -345,11 +365,24 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
outputs=[seed, dummy_component] outputs=[seed, dummy_component]
) )
def update_token_counter(text):
tokens, token_count, max_length = model_hijack.tokenize(text) def update_token_counter(text, steps):
try:
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
except Exception:
# a parsing error can happen here during typing, and we don't want to bother the user with
# messages related to it in console
prompt_schedules = [[[steps, text]]]
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
prompts = [prompt_text for step, prompt_text in flat_prompts]
tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
style_class = ' class="red"' if (token_count > max_length) else "" style_class = ' class="red"' if (token_count > max_length) else ""
return f"<span {style_class}>{token_count}/{max_length}</span>" return f"<span {style_class}>{token_count}/{max_length}</span>"
def create_toprow(is_img2img): def create_toprow(is_img2img):
id_part = "img2img" if is_img2img else "txt2img" id_part = "img2img" if is_img2img else "txt2img"
@ -364,8 +397,7 @@ def create_toprow(is_img2img):
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
paste = gr.Button(value=paste_symbol, elem_id="paste") paste = gr.Button(value=paste_symbol, elem_id="paste")
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter") token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter])
with gr.Column(scale=10, elem_id="style_pos_col"): with gr.Column(scale=10, elem_id="style_pos_col"):
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
@ -380,7 +412,7 @@ def create_toprow(is_img2img):
with gr.Column(scale=1): with gr.Column(scale=1):
with gr.Row(): with gr.Row():
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
submit = gr.Button('Generate', elem_id="generate", variant='primary') submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
interrupt.click( interrupt.click(
fn=lambda: shared.state.interrupt(), fn=lambda: shared.state.interrupt(),
@ -396,16 +428,19 @@ def create_toprow(is_img2img):
prompt_style_apply = gr.Button('Apply style', elem_id="style_apply") prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
save_style = gr.Button('Create style', elem_id="style_create") save_style = gr.Button('Create style', elem_id="style_create")
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste, token_counter, token_button
def setup_progressbar(progressbar, preview, id_part): def setup_progressbar(progressbar, preview, id_part, textinfo=None):
if textinfo is None:
textinfo = gr.HTML(visible=False)
check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
check_progress.click( check_progress.click(
fn=lambda: check_progress_call(id_part), fn=lambda: check_progress_call(id_part),
show_progress=False, show_progress=False,
inputs=[], inputs=[],
outputs=[progressbar, preview, preview], outputs=[progressbar, preview, preview, textinfo],
) )
check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
@ -413,13 +448,16 @@ def setup_progressbar(progressbar, preview, id_part):
fn=lambda: check_progress_call_initial(id_part), fn=lambda: check_progress_call_initial(id_part),
show_progress=False, show_progress=False,
inputs=[], inputs=[],
outputs=[progressbar, preview, preview], outputs=[progressbar, preview, preview, textinfo],
) )
def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): def create_ui(wrap_gradio_gpu_call):
import modules.img2img
import modules.txt2img
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False) txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False) dummy_component = gr.Label(visible=False)
with gr.Row(elem_id='txt2img_progress_row'): with gr.Row(elem_id='txt2img_progress_row'):
@ -483,7 +521,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
txt2img_args = dict( txt2img_args = dict(
fn=txt2img, fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
_js="submit", _js="submit",
inputs=[ inputs=[
txt2img_prompt, txt2img_prompt,
@ -539,6 +577,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
roll.click( roll.click(
fn=roll_artist, fn=roll_artist,
_js="update_txt2img_tokens",
inputs=[ inputs=[
txt2img_prompt, txt2img_prompt,
], ],
@ -567,9 +606,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
] ]
modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt) modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt)
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
with gr.Blocks(analytics_enabled=False) as img2img_interface: with gr.Blocks(analytics_enabled=False) as img2img_interface:
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste = create_toprow(is_img2img=True) img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True)
with gr.Row(elem_id='img2img_progress_row'): with gr.Row(elem_id='img2img_progress_row'):
with gr.Column(scale=1): with gr.Column(scale=1):
@ -585,7 +625,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
with gr.TabItem('img2img', id='img2img'): with gr.TabItem('img2img', id='img2img'):
init_img = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil") init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool)
with gr.TabItem('Inpaint', id='inpaint'): with gr.TabItem('Inpaint', id='inpaint'):
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA") init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA")
@ -599,7 +639,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index") inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index")
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index") inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index")
with gr.Row(): with gr.Row():
inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False) inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False)
@ -607,7 +647,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.TabItem('Batch img2img', id='batch'): with gr.TabItem('Batch img2img', id='batch'):
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.{hidden}</p>") gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs) img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs)
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs) img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs)
@ -675,7 +715,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
) )
img2img_args = dict( img2img_args = dict(
fn=img2img, fn=wrap_gradio_gpu_call(modules.img2img.img2img),
_js="submit_img2img", _js="submit_img2img",
inputs=[ inputs=[
dummy_component, dummy_component,
@ -743,6 +783,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
roll.click( roll.click(
fn=roll_artist, fn=roll_artist,
_js="update_img2img_tokens",
inputs=[ inputs=[
img2img_prompt, img2img_prompt,
], ],
@ -753,6 +794,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
button.click( button.click(
@ -764,9 +806,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2],
) )
for button, (prompt, negative_prompt), (style1, style2) in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns): for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
button.click( button.click(
fn=apply_styles, fn=apply_styles,
_js=js_func,
inputs=[prompt, negative_prompt, style1, style2], inputs=[prompt, negative_prompt, style1, style2],
outputs=[prompt, negative_prompt, style1, style2], outputs=[prompt, negative_prompt, style1, style2],
) )
@ -789,6 +832,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
(denoising_strength, "Denoising strength"), (denoising_strength, "Denoising strength"),
] ]
modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt) modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt)
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
with gr.Blocks(analytics_enabled=False) as extras_interface: with gr.Blocks(analytics_enabled=False) as extras_interface:
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
@ -828,7 +872,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
open_extras_folder = gr.Button('Open output directory', elem_id=button_id) open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
submit.click( submit.click(
fn=run_extras, fn=wrap_gradio_gpu_call(modules.extras.run_extras),
_js="get_extras_tab_index", _js="get_extras_tab_index",
inputs=[ inputs=[
dummy_component, dummy_component,
@ -878,7 +922,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
pnginfo_send_to_img2img = gr.Button('Send to img2img') pnginfo_send_to_img2img = gr.Button('Send to img2img')
image.change( image.change(
fn=wrap_gradio_call(run_pnginfo), fn=wrap_gradio_call(modules.extras.run_pnginfo),
inputs=[image], inputs=[image],
outputs=[html, generation_info, html2], outputs=[html, generation_info, html2],
) )
@ -887,7 +931,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>") gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
with gr.Row(): with gr.Row():
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name") primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name") secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
@ -896,10 +940,134 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method") interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
save_as_half = gr.Checkbox(value=False, label="Safe as float16") save_as_half = gr.Checkbox(value=False, label="Safe as float16")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
with gr.Blocks() as textual_inversion_interface:
with gr.Row().style(equal_height=False):
with gr.Column():
with gr.Group():
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
new_embedding_name = gr.Textbox(label="Name")
initialization_text = gr.Textbox(label="Initialization text", value="*")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
with gr.Column():
create_embedding = gr.Button(value="Create", variant='primary')
with gr.Group():
gr.HTML(value="<p style='margin-bottom: 0.7em'>Preprocess images</p>")
process_src = gr.Textbox(label='Source directory')
process_dst = gr.Textbox(label='Destination directory')
with gr.Row():
process_flip = gr.Checkbox(label='Flip')
process_split = gr.Checkbox(label='Split into two')
process_caption = gr.Checkbox(label='Add caption')
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
with gr.Column():
run_preprocess = gr.Button(value="Preprocess", variant='primary')
with gr.Group():
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
steps = gr.Number(label='Max steps', value=100000, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
with gr.Row():
with gr.Column(scale=2):
gr.HTML(value="")
with gr.Column():
with gr.Row():
interrupt_training = gr.Button(value="Interrupt")
train_embedding = gr.Button(value="Train", variant='primary')
with gr.Column():
progressbar = gr.HTML(elem_id="ti_progressbar")
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
ti_preview = gr.Image(elem_id='ti_preview', visible=False)
ti_progress = gr.HTML(elem_id="ti_progress", value="")
ti_outcome = gr.HTML(elem_id="ti_error", value="")
setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
create_embedding.click(
fn=modules.textual_inversion.ui.create_embedding,
inputs=[
new_embedding_name,
initialization_text,
nvpt,
],
outputs=[
train_embedding_name,
ti_output,
ti_outcome,
]
)
run_preprocess.click(
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
process_src,
process_dst,
process_flip,
process_split,
process_caption,
],
outputs=[
ti_output,
ti_outcome,
],
)
train_embedding.click(
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
train_embedding_name,
learn_rate,
dataset_directory,
log_directory,
steps,
create_image_every,
save_embedding_every,
template_file,
],
outputs=[
ti_output,
ti_outcome,
]
)
interrupt_training.click(
fn=lambda: shared.state.interrupt(),
inputs=[],
outputs=[],
)
def create_setting_component(key): def create_setting_component(key):
def fun(): def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default return opts.data[key] if key in opts.data else opts.data_labels[key].default
@ -1002,6 +1170,32 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
_js='function(){}' _js='function(){}'
) )
with gr.Row():
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
def reload_scripts():
modules.scripts.reload_script_body_only()
reload_script_bodies.click(
fn=reload_scripts,
inputs=[],
outputs=[],
_js='function(){}'
)
def request_restart():
shared.state.interrupt()
settings_interface.gradio_ref.do_restart = True
restart_gradio.click(
fn=request_restart,
inputs=[],
outputs=[],
_js='function(){restart_reload()}'
)
if column is not None: if column is not None:
column.__exit__() column.__exit__()
@ -1011,6 +1205,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
(extras_interface, "Extras", "extras"), (extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"), (pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(textual_inversion_interface, "Textual inversion", "ti"),
(settings_interface, "Settings", "settings"), (settings_interface, "Settings", "settings"),
] ]
@ -1026,7 +1221,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
css += css_hide_progressbar css += css_hide_progressbar
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
settings_interface.gradio_ref = demo
with gr.Tabs() as tabs: with gr.Tabs() as tabs:
for interface, label, ifid in interfaces: for interface, label, ifid in interfaces:
with gr.TabItem(label, id=ifid): with gr.TabItem(label, id=ifid):
@ -1044,11 +1241,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
def modelmerger(*args): def modelmerger(*args):
try: try:
results = run_modelmerger(*args) results = modules.extras.run_modelmerger(*args)
except Exception as e: except Exception as e:
print("Error loading/saving model file:", file=sys.stderr) print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
modules.sd_models.list_models() #To remove the potentially missing models from the list modules.sd_models.list_models() # to remove the potentially missing models from the list
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)] return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
return results return results
@ -1206,12 +1403,12 @@ for filename in sorted(os.listdir(jsdir)):
javascript += f"\n<script>{jsfile.read()}</script>" javascript += f"\n<script>{jsfile.read()}</script>"
def template_response(*args, **kwargs): if 'gradio_routes_templates_response' not in globals():
res = gradio_routes_templates_response(*args, **kwargs) def template_response(*args, **kwargs):
res.body = res.body.replace(b'</head>', f'{javascript}</head>'.encode("utf8")) res = gradio_routes_templates_response(*args, **kwargs)
res.init_headers() res.body = res.body.replace(b'</head>', f'{javascript}</head>'.encode("utf8"))
return res res.init_headers()
return res
gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
gradio_routes_templates_response = gradio.routes.templates.TemplateResponse gradio.routes.templates.TemplateResponse = template_response
gradio.routes.templates.TemplateResponse = template_response

View file

@ -13,14 +13,13 @@ Pillow
pytorch_lightning pytorch_lightning
realesrgan realesrgan
scikit-image>=0.19 scikit-image>=0.19
git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379
timm==0.4.12 timm==0.4.12
transformers==4.19.2 transformers==4.19.2
torch torch
einops einops
jsonmerge jsonmerge
clean-fid clean-fid
git+https://github.com/openai/CLIP@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
resize-right resize-right
torchdiffeq torchdiffeq
kornia kornia
lark

View file

@ -18,7 +18,7 @@ piexif==1.1.3
einops==0.4.1 einops==0.4.1
jsonmerge==1.8.0 jsonmerge==1.8.0
clean-fid==0.1.29 clean-fid==0.1.29
git+https://github.com/openai/CLIP@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
resize-right==0.0.2 resize-right==0.0.2
torchdiffeq==0.2.3 torchdiffeq==0.2.3
kornia==0.6.7 kornia==0.6.7
lark==1.1.2

View file

@ -8,7 +8,6 @@ import gradio as gr
from modules import processing, shared, sd_samplers, prompt_parser from modules import processing, shared, sd_samplers, prompt_parser
from modules.processing import Processed from modules.processing import Processed
from modules.sd_samplers import samplers
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import torch import torch
@ -159,7 +158,7 @@ class Script(scripts.Script):
combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5) combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
sampler = samplers[p.sampler_index].constructor(p.sd_model) sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, p.sampler_index, p.sd_model)
sigmas = sampler.model_wrap.get_sigmas(p.steps) sigmas = sampler.model_wrap.get_sigmas(p.steps)

View file

@ -11,46 +11,8 @@ from modules import images, processing, devices
from modules.processing import Processed, process_images from modules.processing import Processed, process_images
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
# https://github.com/parlance-zz/g-diffuser-bot
def expand(x, dir, amount, power=0.75):
is_left = dir == 3
is_right = dir == 1
is_up = dir == 0
is_down = dir == 2
if is_left or is_right:
noise = np.zeros((x.shape[0], amount, 3), dtype=float)
indexes = np.random.random((x.shape[0], amount)) ** power * (1 - np.arange(amount) / amount)
if is_right:
indexes = 1 - indexes
indexes = (indexes * (x.shape[1] - 1)).astype(int)
for row in range(x.shape[0]):
if is_left:
noise[row] = x[row][indexes[row]]
else:
noise[row] = np.flip(x[row][indexes[row]], axis=0)
x = np.concatenate([noise, x] if is_left else [x, noise], axis=1)
return x
if is_up or is_down:
noise = np.zeros((amount, x.shape[1], 3), dtype=float)
indexes = np.random.random((x.shape[1], amount)) ** power * (1 - np.arange(amount) / amount)
if is_down:
indexes = 1 - indexes
indexes = (indexes * x.shape[0] - 1).astype(int)
for row in range(x.shape[1]):
if is_up:
noise[:, row] = x[:, row][indexes[row]]
else:
noise[:, row] = np.flip(x[:, row][indexes[row]], axis=0)
x = np.concatenate([noise, x] if is_up else [x, noise], axis=0)
return x
# this function is taken from https://github.com/parlance-zz/g-diffuser-bot
def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05): def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05):
# helper fft routines that keep ortho normalization and auto-shift before and after fft # helper fft routines that keep ortho normalization and auto-shift before and after fft
def _fft2(data): def _fft2(data):
@ -123,8 +85,11 @@ def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.0
src_dist = np.absolute(src_fft) src_dist = np.absolute(src_fft)
src_phase = src_fft / src_dist src_phase = src_fft / src_dist
# create a generator with a static seed to make outpainting deterministic / only follow global seed
rng = np.random.default_rng(0)
noise_window = _get_gaussian_window(width, height, mode=1) # start with simple gaussian noise noise_window = _get_gaussian_window(width, height, mode=1) # start with simple gaussian noise
noise_rgb = np.random.random_sample((width, height, num_channels)) noise_rgb = rng.random((width, height, num_channels))
noise_grey = (np.sum(noise_rgb, axis=2) / 3.) noise_grey = (np.sum(noise_rgb, axis=2) / 3.)
noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter
for c in range(num_channels): for c in range(num_channels):

View file

@ -34,7 +34,11 @@ class Script(scripts.Script):
seed = p.seed seed = p.seed
init_img = p.init_images[0] init_img = p.init_images[0]
img = upscaler.scaler.upscale(init_img, 2, upscaler.data_path)
if(upscaler.name != "None"):
img = upscaler.scaler.upscale(init_img, 2, upscaler.data_path)
else:
img = init_img
devices.torch_gc() devices.torch_gc()

View file

@ -1,7 +1,9 @@
from collections import namedtuple from collections import namedtuple
from copy import copy from copy import copy
from itertools import permutations, chain
import random import random
import csv
from io import StringIO
from PIL import Image from PIL import Image
import numpy as np import numpy as np
@ -29,6 +31,31 @@ def apply_prompt(p, x, xs):
p.negative_prompt = p.negative_prompt.replace(xs[0], x) p.negative_prompt = p.negative_prompt.replace(xs[0], x)
def apply_order(p, x, xs):
token_order = []
# Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen
for token in x:
token_order.append((p.prompt.find(token), token))
token_order.sort(key=lambda t: t[0])
prompt_parts = []
# Split the prompt up, taking out the tokens
for _, token in token_order:
n = p.prompt.find(token)
prompt_parts.append(p.prompt[0:n])
p.prompt = p.prompt[n + len(token):]
# Rebuild the prompt with the tokens in the order we want
prompt_tmp = ""
for idx, part in enumerate(prompt_parts):
prompt_tmp += part
prompt_tmp += x[idx]
p.prompt = prompt_tmp + p.prompt
samplers_dict = {} samplers_dict = {}
for i, sampler in enumerate(modules.sd_samplers.samplers): for i, sampler in enumerate(modules.sd_samplers.samplers):
samplers_dict[sampler.name.lower()] = i samplers_dict[sampler.name.lower()] = i
@ -60,16 +87,26 @@ def format_value_add_label(p, opt, x):
def format_value(p, opt, x): def format_value(p, opt, x):
if type(x) == float: if type(x) == float:
x = round(x, 8) x = round(x, 8)
return x return x
def format_value_join_list(p, opt, x):
return ", ".join(x)
def do_nothing(p, x, xs): def do_nothing(p, x, xs):
pass pass
def format_nothing(p, opt, x): def format_nothing(p, opt, x):
return "" return ""
def str_permutations(x):
"""dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
return x
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"]) AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"])
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"]) AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"])
@ -82,6 +119,7 @@ axis_options = [
AxisOption("Steps", int, apply_field("steps"), format_value_add_label), AxisOption("Steps", int, apply_field("steps"), format_value_add_label),
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label), AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
AxisOption("Prompt S/R", str, apply_prompt, format_value), AxisOption("Prompt S/R", str, apply_prompt, format_value),
AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list),
AxisOption("Sampler", str, apply_sampler, format_value), AxisOption("Sampler", str, apply_sampler, format_value),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value), AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label), AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label),
@ -159,7 +197,7 @@ class Script(scripts.Script):
if opt.label == 'Nothing': if opt.label == 'Nothing':
return [0] return [0]
valslist = [x.strip() for x in vals.split(",")] valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))]
if opt.type == int: if opt.type == int:
valslist_ext = [] valslist_ext = []
@ -206,6 +244,8 @@ class Script(scripts.Script):
valslist_ext.append(val) valslist_ext.append(val)
valslist = valslist_ext valslist = valslist_ext
elif opt.type == str_permutations:
valslist = list(permutations(valslist))
valslist = [opt.type(x) for x in valslist] valslist = [opt.type(x) for x in valslist]

View file

@ -23,7 +23,7 @@
text-align: right; text-align: right;
} }
#generate{ #txt2img_generate, #img2img_generate {
min-height: 4.5em; min-height: 4.5em;
} }
@ -157,7 +157,7 @@ button{
max-width: 10em; max-width: 10em;
} }
#txt2img_preview, #img2img_preview{ #txt2img_preview, #img2img_preview, #ti_preview{
position: absolute; position: absolute;
width: 320px; width: 320px;
left: 0; left: 0;
@ -172,18 +172,18 @@ button{
} }
@media screen and (min-width: 768px) { @media screen and (min-width: 768px) {
#txt2img_preview, #img2img_preview { #txt2img_preview, #img2img_preview, #ti_preview {
position: absolute; position: absolute;
} }
} }
@media screen and (max-width: 767px) { @media screen and (max-width: 767px) {
#txt2img_preview, #img2img_preview { #txt2img_preview, #img2img_preview, #ti_preview {
position: relative; position: relative;
} }
} }
#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0{ #txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0, #ti_preview div.left-0.top-0{
display: none; display: none;
} }
@ -247,7 +247,7 @@ input[type="range"]{
#txt2img_negative_prompt, #img2img_negative_prompt{ #txt2img_negative_prompt, #img2img_negative_prompt{
} }
#txt2img_progressbar, #img2img_progressbar{ #txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
position: absolute; position: absolute;
z-index: 1000; z-index: 1000;
right: 0; right: 0;
@ -407,3 +407,7 @@ input[type="range"]{
.gallery-item { .gallery-item {
--tw-bg-opacity: 0 !important; --tw-bg-opacity: 0 !important;
} }
#img2img_image div.h-60{
height: 480px;
}

View file

@ -0,0 +1,19 @@
a painting, art by [name]
a rendering, art by [name]
a cropped painting, art by [name]
the painting, art by [name]
a clean painting, art by [name]
a dirty painting, art by [name]
a dark painting, art by [name]
a picture, art by [name]
a cool painting, art by [name]
a close-up painting, art by [name]
a bright painting, art by [name]
a cropped painting, art by [name]
a good painting, art by [name]
a close-up painting, art by [name]
a rendition, art by [name]
a nice painting, art by [name]
a small painting, art by [name]
a weird painting, art by [name]
a large painting, art by [name]

View file

@ -0,0 +1,19 @@
a painting of [filewords], art by [name]
a rendering of [filewords], art by [name]
a cropped painting of [filewords], art by [name]
the painting of [filewords], art by [name]
a clean painting of [filewords], art by [name]
a dirty painting of [filewords], art by [name]
a dark painting of [filewords], art by [name]
a picture of [filewords], art by [name]
a cool painting of [filewords], art by [name]
a close-up painting of [filewords], art by [name]
a bright painting of [filewords], art by [name]
a cropped painting of [filewords], art by [name]
a good painting of [filewords], art by [name]
a close-up painting of [filewords], art by [name]
a rendition of [filewords], art by [name]
a nice painting of [filewords], art by [name]
a small painting of [filewords], art by [name]
a weird painting of [filewords], art by [name]
a large painting of [filewords], art by [name]

View file

@ -0,0 +1,27 @@
a photo of a [name]
a rendering of a [name]
a cropped photo of the [name]
the photo of a [name]
a photo of a clean [name]
a photo of a dirty [name]
a dark photo of the [name]
a photo of my [name]
a photo of the cool [name]
a close-up photo of a [name]
a bright photo of the [name]
a cropped photo of a [name]
a photo of the [name]
a good photo of the [name]
a photo of one [name]
a close-up photo of the [name]
a rendition of the [name]
a photo of the clean [name]
a rendition of a [name]
a photo of a nice [name]
a good photo of a [name]
a photo of the nice [name]
a photo of the small [name]
a photo of the weird [name]
a photo of the large [name]
a photo of a cool [name]
a photo of a small [name]

View file

@ -0,0 +1,27 @@
a photo of a [name], [filewords]
a rendering of a [name], [filewords]
a cropped photo of the [name], [filewords]
the photo of a [name], [filewords]
a photo of a clean [name], [filewords]
a photo of a dirty [name], [filewords]
a dark photo of the [name], [filewords]
a photo of my [name], [filewords]
a photo of the cool [name], [filewords]
a close-up photo of a [name], [filewords]
a bright photo of the [name], [filewords]
a cropped photo of a [name], [filewords]
a photo of the [name], [filewords]
a good photo of the [name], [filewords]
a photo of one [name], [filewords]
a close-up photo of the [name], [filewords]
a rendition of the [name], [filewords]
a photo of the clean [name], [filewords]
a rendition of a [name], [filewords]
a photo of a nice [name], [filewords]
a good photo of a [name], [filewords]
a photo of the nice [name], [filewords]
a photo of the small [name], [filewords]
a photo of the weird [name], [filewords]
a photo of the large [name], [filewords]
a photo of a cool [name], [filewords]
a photo of a small [name], [filewords]

View file

@ -1,34 +1,35 @@
import os import os
import threading import threading
import time
from modules import devices import importlib
from modules.paths import script_path
import signal import signal
import threading import threading
import modules.paths
from modules.paths import script_path
from modules import devices, sd_samplers
import modules.codeformer_model as codeformer import modules.codeformer_model as codeformer
import modules.esrgan_model as esrgan
import modules.bsrgan_model as bsrgan
import modules.extras import modules.extras
import modules.face_restoration import modules.face_restoration
import modules.gfpgan_model as gfpgan import modules.gfpgan_model as gfpgan
import modules.img2img import modules.img2img
import modules.ldsr_model as ldsr
import modules.lowvram import modules.lowvram
import modules.realesrgan_model as realesrgan import modules.paths
import modules.scripts import modules.scripts
import modules.sd_hijack import modules.sd_hijack
import modules.sd_models import modules.sd_models
import modules.shared as shared import modules.shared as shared
import modules.swinir_model as swinir
import modules.txt2img import modules.txt2img
import modules.ui import modules.ui
from modules import devices
from modules import modelloader from modules import modelloader
from modules.paths import script_path from modules.paths import script_path
from modules.shared import cmd_opts from modules.shared import cmd_opts
modelloader.cleanup_models() modelloader.cleanup_models()
modules.sd_models.setup_model(cmd_opts.ckpt_dir) modules.sd_models.setup_model()
codeformer.setup_model(cmd_opts.codeformer_models_path) codeformer.setup_model(cmd_opts.codeformer_models_path)
gfpgan.setup_model(cmd_opts.gfpgan_models_path) gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration()) shared.face_restorers.append(modules.face_restoration.FaceRestoration())
@ -46,7 +47,7 @@ def wrap_queued_call(func):
return f return f
def wrap_gradio_gpu_call(func): def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs): def f(*args, **kwargs):
devices.torch_gc() devices.torch_gc()
@ -58,6 +59,7 @@ def wrap_gradio_gpu_call(func):
shared.state.current_image = None shared.state.current_image = None
shared.state.current_image_sampling_step = 0 shared.state.current_image_sampling_step = 0
shared.state.interrupted = False shared.state.interrupted = False
shared.state.textinfo = None
with queue_lock: with queue_lock:
res = func(*args, **kwargs) res = func(*args, **kwargs)
@ -69,7 +71,7 @@ def wrap_gradio_gpu_call(func):
return res return res
return modules.ui.wrap_gradio_call(f) return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
modules.scripts.load_scripts(os.path.join(script_path, "scripts")) modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
@ -86,22 +88,36 @@ def webui():
signal.signal(signal.SIGINT, sigint_handler) signal.signal(signal.SIGINT, sigint_handler)
demo = modules.ui.create_ui( while 1:
txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img),
img2img=wrap_gradio_gpu_call(modules.img2img.img2img), demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
run_extras=wrap_gradio_gpu_call(modules.extras.run_extras),
run_pnginfo=modules.extras.run_pnginfo, demo.launch(
run_modelmerger=modules.extras.run_modelmerger share=cmd_opts.share,
) server_name="0.0.0.0" if cmd_opts.listen else None,
server_port=cmd_opts.port,
debug=cmd_opts.gradio_debug,
auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None,
inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True
)
while 1:
time.sleep(0.5)
if getattr(demo, 'do_restart', False):
time.sleep(0.5)
demo.close()
time.sleep(0.5)
break
sd_samplers.set_samplers()
print('Reloading Custom Scripts')
modules.scripts.reload_scripts(os.path.join(script_path, "scripts"))
print('Reloading modules: modules.ui')
importlib.reload(modules.ui)
print('Restarting Gradio')
demo.launch(
share=cmd_opts.share,
server_name="0.0.0.0" if cmd_opts.listen else None,
server_port=cmd_opts.port,
debug=cmd_opts.gradio_debug,
auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None,
inbrowser=cmd_opts.autolaunch,
)
if __name__ == "__main__": if __name__ == "__main__":