Merge branch 'master' into saving
This commit is contained in:
commit
f28ce3e3a1
20 changed files with 472 additions and 131 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -4,7 +4,7 @@ __pycache__
|
||||||
/venv
|
/venv
|
||||||
/tmp
|
/tmp
|
||||||
/model.ckpt
|
/model.ckpt
|
||||||
/models/*.ckpt
|
/models/**/*.ckpt
|
||||||
/GFPGANv1.3.pth
|
/GFPGANv1.3.pth
|
||||||
/gfpgan/weights/*.pth
|
/gfpgan/weights/*.pth
|
||||||
/ui-config.json
|
/ui-config.json
|
||||||
|
|
|
@ -15,6 +15,7 @@ titles = {
|
||||||
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
|
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
|
||||||
"\u{1f3a8}": "Add a random artist to the prompt.",
|
"\u{1f3a8}": "Add a random artist to the prompt.",
|
||||||
"\u2199\ufe0f": "Read generation parameters from prompt into user interface.",
|
"\u2199\ufe0f": "Read generation parameters from prompt into user interface.",
|
||||||
|
"\uD83D\uDCC2": "Open images output directory",
|
||||||
|
|
||||||
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
|
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
|
||||||
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
|
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
|
||||||
|
|
|
@ -182,4 +182,23 @@ onUiUpdate(function(){
|
||||||
});
|
});
|
||||||
|
|
||||||
json_elem.parentElement.style.display="none"
|
json_elem.parentElement.style.display="none"
|
||||||
|
|
||||||
|
if (!txt2img_textarea) {
|
||||||
|
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
|
||||||
|
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
|
||||||
|
}
|
||||||
|
if (!img2img_textarea) {
|
||||||
|
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
|
||||||
|
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
let txt2img_textarea, img2img_textarea = undefined;
|
||||||
|
let wait_time = 800
|
||||||
|
let token_timeout;
|
||||||
|
|
||||||
|
function update_token_counter(button_id) {
|
||||||
|
if (token_timeout)
|
||||||
|
clearTimeout(token_timeout);
|
||||||
|
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
||||||
|
}
|
||||||
|
|
|
@ -15,14 +15,14 @@ torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113
|
||||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
||||||
|
|
||||||
k_diffusion_package = os.environ.get('K_DIFFUSION_PACKAGE', "git+https://github.com/crowsonkb/k-diffusion.git@1a0703dfb7d24d8806267c3e7ccc4caf67fd1331")
|
|
||||||
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
||||||
|
|
||||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
|
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
|
||||||
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
|
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
|
||||||
|
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "9e3002b7cd64df7870e08527b7664eb2f2f5f3f5")
|
||||||
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
||||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||||
ldsr_commit_hash = os.environ.get('LDSR_COMMIT_HASH',"abf33e7002d59d9085081bce93ec798dcabd49af")
|
ldsr_commit_hash = os.environ.get('LDSR_COMMIT_HASH', "abf33e7002d59d9085081bce93ec798dcabd49af")
|
||||||
|
|
||||||
args = shlex.split(commandline_args)
|
args = shlex.split(commandline_args)
|
||||||
|
|
||||||
|
@ -110,9 +110,6 @@ if not is_installed("torch") or not is_installed("torchvision"):
|
||||||
if not skip_torch_cuda_test:
|
if not skip_torch_cuda_test:
|
||||||
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
|
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
|
||||||
|
|
||||||
if not is_installed("k_diffusion.sampling"):
|
|
||||||
run_pip(f"install {k_diffusion_package}", "k-diffusion")
|
|
||||||
|
|
||||||
if not is_installed("gfpgan"):
|
if not is_installed("gfpgan"):
|
||||||
run_pip(f"install {gfpgan_package}", "gfpgan")
|
run_pip(f"install {gfpgan_package}", "gfpgan")
|
||||||
|
|
||||||
|
@ -120,6 +117,7 @@ os.makedirs(dir_repos, exist_ok=True)
|
||||||
|
|
||||||
git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
|
git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||||
git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
|
git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
|
||||||
|
git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
||||||
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
||||||
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||||
# Using my repo until my changes are merged, as this makes interfacing with our version of SD-web a lot easier
|
# Using my repo until my changes are merged, as this makes interfacing with our version of SD-web a lot easier
|
||||||
|
|
|
@ -6,13 +6,14 @@ from PIL import Image
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from modules import processing, shared, images, devices
|
from modules import processing, shared, images, devices, sd_models
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
import modules.gfpgan_model
|
import modules.gfpgan_model
|
||||||
from modules.ui import plaintext_to_html
|
from modules.ui import plaintext_to_html
|
||||||
import modules.codeformer_model
|
import modules.codeformer_model
|
||||||
import piexif
|
import piexif
|
||||||
import piexif.helper
|
import piexif.helper
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
|
||||||
cached_images = {}
|
cached_images = {}
|
||||||
|
@ -140,7 +141,7 @@ def run_pnginfo(image):
|
||||||
return '', geninfo, info
|
return '', geninfo, info
|
||||||
|
|
||||||
|
|
||||||
def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount):
|
def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name):
|
||||||
# Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
|
# Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
|
||||||
def weighted_sum(theta0, theta1, alpha):
|
def weighted_sum(theta0, theta1, alpha):
|
||||||
return ((1 - alpha) * theta0) + (alpha * theta1)
|
return ((1 - alpha) * theta0) + (alpha * theta1)
|
||||||
|
@ -150,23 +151,20 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
|
||||||
alpha = alpha * alpha * (3 - (2 * alpha))
|
alpha = alpha * alpha * (3 - (2 * alpha))
|
||||||
return theta0 + ((theta1 - theta0) * alpha)
|
return theta0 + ((theta1 - theta0) * alpha)
|
||||||
|
|
||||||
if os.path.exists(primary_model_name):
|
# Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
|
||||||
primary_model_filename = primary_model_name
|
def inv_sigmoid(theta0, theta1, alpha):
|
||||||
primary_model_name = os.path.splitext(os.path.basename(primary_model_name))[0]
|
import math
|
||||||
else:
|
alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
|
||||||
primary_model_filename = 'models/' + primary_model_name + '.ckpt'
|
return theta0 + ((theta1 - theta0) * alpha)
|
||||||
|
|
||||||
if os.path.exists(secondary_model_name):
|
primary_model_info = sd_models.checkpoints_list[primary_model_name]
|
||||||
secondary_model_filename = secondary_model_name
|
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
|
||||||
secondary_model_name = os.path.splitext(os.path.basename(secondary_model_name))[0]
|
|
||||||
else:
|
|
||||||
secondary_model_filename = 'models/' + secondary_model_name + '.ckpt'
|
|
||||||
|
|
||||||
print(f"Loading {primary_model_filename}...")
|
print(f"Loading {primary_model_info.filename}...")
|
||||||
primary_model = torch.load(primary_model_filename, map_location='cpu')
|
primary_model = torch.load(primary_model_info.filename, map_location='cpu')
|
||||||
|
|
||||||
print(f"Loading {secondary_model_filename}...")
|
print(f"Loading {secondary_model_info.filename}...")
|
||||||
secondary_model = torch.load(secondary_model_filename, map_location='cpu')
|
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
|
||||||
|
|
||||||
theta_0 = primary_model['state_dict']
|
theta_0 = primary_model['state_dict']
|
||||||
theta_1 = secondary_model['state_dict']
|
theta_1 = secondary_model['state_dict']
|
||||||
|
@ -174,6 +172,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
|
||||||
theta_funcs = {
|
theta_funcs = {
|
||||||
"Weighted Sum": weighted_sum,
|
"Weighted Sum": weighted_sum,
|
||||||
"Sigmoid": sigmoid,
|
"Sigmoid": sigmoid,
|
||||||
|
"Inverse Sigmoid": inv_sigmoid,
|
||||||
}
|
}
|
||||||
theta_func = theta_funcs[interp_method]
|
theta_func = theta_funcs[interp_method]
|
||||||
|
|
||||||
|
@ -181,14 +180,23 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
|
||||||
for key in tqdm.tqdm(theta_0.keys()):
|
for key in tqdm.tqdm(theta_0.keys()):
|
||||||
if 'model' in key and key in theta_1:
|
if 'model' in key and key in theta_1:
|
||||||
theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint
|
theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint
|
||||||
|
if save_as_half:
|
||||||
|
theta_0[key] = theta_0[key].half()
|
||||||
|
|
||||||
for key in theta_1.keys():
|
for key in theta_1.keys():
|
||||||
if 'model' in key and key not in theta_0:
|
if 'model' in key and key not in theta_0:
|
||||||
theta_0[key] = theta_1[key]
|
theta_0[key] = theta_1[key]
|
||||||
|
if save_as_half:
|
||||||
|
theta_0[key] = theta_0[key].half()
|
||||||
|
|
||||||
|
filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
|
||||||
|
filename = filename if custom_name == '' else (custom_name + '.ckpt')
|
||||||
|
output_modelname = os.path.join(shared.cmd_opts.ckpt_dir, filename)
|
||||||
|
|
||||||
output_modelname = 'models/' + primary_model_name + '_' + str(round(interp_amount,2)) + '-' + secondary_model_name + '_' + str(round((float(1.0) - interp_amount),2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
|
|
||||||
print(f"Saving to {output_modelname}...")
|
print(f"Saving to {output_modelname}...")
|
||||||
torch.save(primary_model, output_modelname)
|
torch.save(primary_model, output_modelname)
|
||||||
|
|
||||||
|
sd_models.list_models()
|
||||||
|
|
||||||
print(f"Checkpoint saved.")
|
print(f"Checkpoint saved.")
|
||||||
return "Checkpoint saved to " + output_modelname
|
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)]
|
||||||
|
|
|
@ -20,6 +20,7 @@ path_dirs = [
|
||||||
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer'),
|
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer'),
|
||||||
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP'),
|
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP'),
|
||||||
(os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR'),
|
(os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR'),
|
||||||
|
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion'),
|
||||||
]
|
]
|
||||||
|
|
||||||
paths = {}
|
paths = {}
|
||||||
|
|
|
@ -49,7 +49,7 @@ def apply_color_correction(correction, image):
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionProcessing:
|
class StableDiffusionProcessing:
|
||||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
|
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
|
||||||
self.sd_model = sd_model
|
self.sd_model = sd_model
|
||||||
self.outpath_samples: str = outpath_samples
|
self.outpath_samples: str = outpath_samples
|
||||||
self.outpath_grids: str = outpath_grids
|
self.outpath_grids: str = outpath_grids
|
||||||
|
@ -75,11 +75,11 @@ class StableDiffusionProcessing:
|
||||||
self.do_not_save_grid: bool = do_not_save_grid
|
self.do_not_save_grid: bool = do_not_save_grid
|
||||||
self.extra_generation_params: dict = extra_generation_params or {}
|
self.extra_generation_params: dict = extra_generation_params or {}
|
||||||
self.overlay_images = overlay_images
|
self.overlay_images = overlay_images
|
||||||
|
self.eta = eta
|
||||||
self.paste_to = None
|
self.paste_to = None
|
||||||
self.color_corrections = None
|
self.color_corrections = None
|
||||||
self.denoising_strength: float = 0
|
self.denoising_strength: float = 0
|
||||||
|
|
||||||
self.eta = opts.eta
|
|
||||||
self.ddim_discretize = opts.ddim_discretize
|
self.ddim_discretize = opts.ddim_discretize
|
||||||
self.s_churn = opts.s_churn
|
self.s_churn = opts.s_churn
|
||||||
self.s_tmin = opts.s_tmin
|
self.s_tmin = opts.s_tmin
|
||||||
|
@ -100,7 +100,7 @@ class StableDiffusionProcessing:
|
||||||
|
|
||||||
|
|
||||||
class Processed:
|
class Processed:
|
||||||
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0):
|
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
|
||||||
self.images = images_list
|
self.images = images_list
|
||||||
self.prompt = p.prompt
|
self.prompt = p.prompt
|
||||||
self.negative_prompt = p.negative_prompt
|
self.negative_prompt = p.negative_prompt
|
||||||
|
@ -139,6 +139,7 @@ class Processed:
|
||||||
self.all_prompts = all_prompts or [self.prompt]
|
self.all_prompts = all_prompts or [self.prompt]
|
||||||
self.all_seeds = all_seeds or [self.seed]
|
self.all_seeds = all_seeds or [self.seed]
|
||||||
self.all_subseeds = all_subseeds or [self.subseed]
|
self.all_subseeds = all_subseeds or [self.subseed]
|
||||||
|
self.infotexts = infotexts or [info]
|
||||||
|
|
||||||
def js(self):
|
def js(self):
|
||||||
obj = {
|
obj = {
|
||||||
|
@ -165,6 +166,7 @@ class Processed:
|
||||||
"denoising_strength": self.denoising_strength,
|
"denoising_strength": self.denoising_strength,
|
||||||
"extra_generation_params": self.extra_generation_params,
|
"extra_generation_params": self.extra_generation_params,
|
||||||
"index_of_first_image": self.index_of_first_image,
|
"index_of_first_image": self.index_of_first_image,
|
||||||
|
"infotexts": self.infotexts,
|
||||||
}
|
}
|
||||||
|
|
||||||
return json.dumps(obj)
|
return json.dumps(obj)
|
||||||
|
@ -269,6 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
||||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||||
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||||
"Denoising strength": getattr(p, 'denoising_strength', None),
|
"Denoising strength": getattr(p, 'denoising_strength', None),
|
||||||
|
"Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
||||||
}
|
}
|
||||||
|
|
||||||
generation_params.update(p.extra_generation_params)
|
generation_params.update(p.extra_generation_params)
|
||||||
|
@ -277,7 +280,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
||||||
|
|
||||||
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
|
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
|
||||||
|
|
||||||
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])
|
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
|
||||||
|
|
||||||
|
|
||||||
def process_images(p: StableDiffusionProcessing) -> Processed:
|
def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
|
@ -322,6 +325,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
if os.path.exists(cmd_opts.embeddings_dir):
|
if os.path.exists(cmd_opts.embeddings_dir):
|
||||||
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
|
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
|
||||||
|
|
||||||
|
infotexts = []
|
||||||
output_images = []
|
output_images = []
|
||||||
precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
|
precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
|
||||||
ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
|
ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
|
||||||
|
@ -404,6 +408,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
if opts.samples_save and not p.do_not_save_samples:
|
if opts.samples_save and not p.do_not_save_samples:
|
||||||
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
|
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
|
||||||
|
|
||||||
|
infotexts.append(infotext(n, i))
|
||||||
output_images.append(image)
|
output_images.append(image)
|
||||||
|
|
||||||
state.nextjob()
|
state.nextjob()
|
||||||
|
@ -416,6 +421,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
grid = images.image_grid(output_images, p.batch_size)
|
grid = images.image_grid(output_images, p.batch_size)
|
||||||
|
|
||||||
if opts.return_grid:
|
if opts.return_grid:
|
||||||
|
infotexts.insert(0, infotext())
|
||||||
output_images.insert(0, grid)
|
output_images.insert(0, grid)
|
||||||
index_of_first_image = 1
|
index_of_first_image = 1
|
||||||
|
|
||||||
|
@ -423,7 +429,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
return Processed(p, output_images, all_seeds[0], infotext(), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image)
|
return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
|
|
|
@ -126,5 +126,93 @@ def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
re_attention = re.compile(r"""
|
||||||
|
\\\(|
|
||||||
|
\\\)|
|
||||||
|
\\\[|
|
||||||
|
\\]|
|
||||||
|
\\\\|
|
||||||
|
\\|
|
||||||
|
\(|
|
||||||
|
\[|
|
||||||
|
:([+-]?[.\d]+)\)|
|
||||||
|
\)|
|
||||||
|
]|
|
||||||
|
[^\\()\[\]:]+|
|
||||||
|
:
|
||||||
|
""", re.X)
|
||||||
|
|
||||||
#get_learned_conditioning_prompt_schedules(["fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"], 100)
|
|
||||||
|
def parse_prompt_attention(text):
|
||||||
|
"""
|
||||||
|
Parses a string with attention tokens and returns a list of pairs: text and its assoicated weight.
|
||||||
|
Accepted tokens are:
|
||||||
|
(abc) - increases attention to abc by a multiplier of 1.1
|
||||||
|
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
||||||
|
[abc] - decreases attention to abc by a multiplier of 1.1
|
||||||
|
\( - literal character '('
|
||||||
|
\[ - literal character '['
|
||||||
|
\) - literal character ')'
|
||||||
|
\] - literal character ']'
|
||||||
|
\\ - literal character '\'
|
||||||
|
anything else - just text
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).'
|
||||||
|
|
||||||
|
produces:
|
||||||
|
|
||||||
|
[
|
||||||
|
['a ', 1.0],
|
||||||
|
['house', 1.5730000000000004],
|
||||||
|
[' ', 1.1],
|
||||||
|
['on', 1.0],
|
||||||
|
[' a ', 1.1],
|
||||||
|
['hill', 0.55],
|
||||||
|
[', sun, ', 1.1],
|
||||||
|
['sky', 1.4641000000000006],
|
||||||
|
['.', 1.1]
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
|
||||||
|
res = []
|
||||||
|
round_brackets = []
|
||||||
|
square_brackets = []
|
||||||
|
|
||||||
|
round_bracket_multiplier = 1.1
|
||||||
|
square_bracket_multiplier = 1 / 1.1
|
||||||
|
|
||||||
|
def multiply_range(start_position, multiplier):
|
||||||
|
for p in range(start_position, len(res)):
|
||||||
|
res[p][1] *= multiplier
|
||||||
|
|
||||||
|
for m in re_attention.finditer(text):
|
||||||
|
text = m.group(0)
|
||||||
|
weight = m.group(1)
|
||||||
|
|
||||||
|
if text.startswith('\\'):
|
||||||
|
res.append([text[1:], 1.0])
|
||||||
|
elif text == '(':
|
||||||
|
round_brackets.append(len(res))
|
||||||
|
elif text == '[':
|
||||||
|
square_brackets.append(len(res))
|
||||||
|
elif weight is not None and len(round_brackets) > 0:
|
||||||
|
multiply_range(round_brackets.pop(), float(weight))
|
||||||
|
elif text == ')' and len(round_brackets) > 0:
|
||||||
|
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
||||||
|
elif text == ']' and len(square_brackets) > 0:
|
||||||
|
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
||||||
|
else:
|
||||||
|
res.append([text, 1.0])
|
||||||
|
|
||||||
|
for pos in round_brackets:
|
||||||
|
multiply_range(pos, round_bracket_multiplier)
|
||||||
|
|
||||||
|
for pos in square_brackets:
|
||||||
|
multiply_range(pos, square_bracket_multiplier)
|
||||||
|
|
||||||
|
if len(res) == 0:
|
||||||
|
res = [["", 1.0]]
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
|
@ -55,7 +55,7 @@ def load_scripts(basedir):
|
||||||
if not os.path.exists(basedir):
|
if not os.path.exists(basedir):
|
||||||
return
|
return
|
||||||
|
|
||||||
for filename in os.listdir(basedir):
|
for filename in sorted(os.listdir(basedir)):
|
||||||
path = os.path.join(basedir, filename)
|
path = os.path.join(basedir, filename)
|
||||||
|
|
||||||
if not os.path.isfile(path):
|
if not os.path.isfile(path):
|
||||||
|
|
|
@ -6,6 +6,7 @@ import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
|
|
||||||
|
from modules import prompt_parser
|
||||||
from modules.shared import opts, device, cmd_opts
|
from modules.shared import opts, device, cmd_opts
|
||||||
|
|
||||||
from ldm.util import default
|
from ldm.util import default
|
||||||
|
@ -180,6 +181,7 @@ class StableDiffusionModelHijack:
|
||||||
dir_mtime = None
|
dir_mtime = None
|
||||||
layers = None
|
layers = None
|
||||||
circular_enabled = False
|
circular_enabled = False
|
||||||
|
clip = None
|
||||||
|
|
||||||
def load_textual_inversion_embeddings(self, dirname, model):
|
def load_textual_inversion_embeddings(self, dirname, model):
|
||||||
mt = os.path.getmtime(dirname)
|
mt = os.path.getmtime(dirname)
|
||||||
|
@ -210,6 +212,7 @@ class StableDiffusionModelHijack:
|
||||||
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||||
emb = next(iter(param_dict.items()))[1]
|
emb = next(iter(param_dict.items()))[1]
|
||||||
|
# diffuser concepts
|
||||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
||||||
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
||||||
|
|
||||||
|
@ -235,13 +238,14 @@ class StableDiffusionModelHijack:
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print(f"Loaded a total of {len(self.word_embeddings)} text inversion embeddings.")
|
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
|
||||||
|
|
||||||
def hijack(self, m):
|
def hijack(self, m):
|
||||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||||
|
|
||||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||||
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||||
|
self.clip = m.cond_stage_model
|
||||||
|
|
||||||
if cmd_opts.opt_split_attention_v1:
|
if cmd_opts.opt_split_attention_v1:
|
||||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
|
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
|
||||||
|
@ -268,6 +272,11 @@ class StableDiffusionModelHijack:
|
||||||
for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
|
for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
|
||||||
layer.padding_mode = 'circular' if enable else 'zeros'
|
layer.padding_mode = 'circular' if enable else 'zeros'
|
||||||
|
|
||||||
|
def tokenize(self, text):
|
||||||
|
max_length = self.clip.max_length - 2
|
||||||
|
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
|
||||||
|
return remade_batch_tokens[0], token_count, max_length
|
||||||
|
|
||||||
|
|
||||||
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
def __init__(self, wrapped, hijack):
|
def __init__(self, wrapped, hijack):
|
||||||
|
@ -294,14 +303,101 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
if mult != 1.0:
|
if mult != 1.0:
|
||||||
self.token_mults[ident] = mult
|
self.token_mults[ident] = mult
|
||||||
|
|
||||||
def forward(self, text):
|
|
||||||
self.hijack.fixes = []
|
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
||||||
self.hijack.comments = []
|
id_start = self.wrapped.tokenizer.bos_token_id
|
||||||
|
id_end = self.wrapped.tokenizer.eos_token_id
|
||||||
|
maxlen = self.wrapped.max_length
|
||||||
|
|
||||||
|
if opts.enable_emphasis:
|
||||||
|
parsed = prompt_parser.parse_prompt_attention(line)
|
||||||
|
else:
|
||||||
|
parsed = [[line, 1.0]]
|
||||||
|
|
||||||
|
tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
|
||||||
|
|
||||||
|
fixes = []
|
||||||
|
remade_tokens = []
|
||||||
|
multipliers = []
|
||||||
|
|
||||||
|
for tokens, (text, weight) in zip(tokenized, parsed):
|
||||||
|
i = 0
|
||||||
|
while i < len(tokens):
|
||||||
|
token = tokens[i]
|
||||||
|
|
||||||
|
possible_matches = self.hijack.ids_lookup.get(token, None)
|
||||||
|
|
||||||
|
if possible_matches is None:
|
||||||
|
remade_tokens.append(token)
|
||||||
|
multipliers.append(weight)
|
||||||
|
else:
|
||||||
|
found = False
|
||||||
|
for ids, word in possible_matches:
|
||||||
|
if tokens[i:i + len(ids)] == ids:
|
||||||
|
emb_len = int(self.hijack.word_embeddings[word].shape[0])
|
||||||
|
fixes.append((len(remade_tokens), word))
|
||||||
|
remade_tokens += [0] * emb_len
|
||||||
|
multipliers += [weight] * emb_len
|
||||||
|
i += len(ids) - 1
|
||||||
|
found = True
|
||||||
|
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
|
||||||
|
break
|
||||||
|
|
||||||
|
if not found:
|
||||||
|
remade_tokens.append(token)
|
||||||
|
multipliers.append(weight)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
if len(remade_tokens) > maxlen - 2:
|
||||||
|
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||||
|
ovf = remade_tokens[maxlen - 2:]
|
||||||
|
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||||
|
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||||
|
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||||
|
|
||||||
|
token_count = len(remade_tokens)
|
||||||
|
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||||
|
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
|
||||||
|
|
||||||
|
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
||||||
|
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
||||||
|
|
||||||
|
return remade_tokens, fixes, multipliers, token_count
|
||||||
|
|
||||||
|
def process_text(self, texts):
|
||||||
|
used_custom_terms = []
|
||||||
remade_batch_tokens = []
|
remade_batch_tokens = []
|
||||||
|
hijack_comments = []
|
||||||
|
hijack_fixes = []
|
||||||
|
token_count = 0
|
||||||
|
|
||||||
|
cache = {}
|
||||||
|
batch_multipliers = []
|
||||||
|
for line in texts:
|
||||||
|
if line in cache:
|
||||||
|
remade_tokens, fixes, multipliers = cache[line]
|
||||||
|
else:
|
||||||
|
remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
|
||||||
|
|
||||||
|
cache[line] = (remade_tokens, fixes, multipliers)
|
||||||
|
|
||||||
|
remade_batch_tokens.append(remade_tokens)
|
||||||
|
hijack_fixes.append(fixes)
|
||||||
|
batch_multipliers.append(multipliers)
|
||||||
|
|
||||||
|
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||||
|
|
||||||
|
|
||||||
|
def process_text_old(self, text):
|
||||||
id_start = self.wrapped.tokenizer.bos_token_id
|
id_start = self.wrapped.tokenizer.bos_token_id
|
||||||
id_end = self.wrapped.tokenizer.eos_token_id
|
id_end = self.wrapped.tokenizer.eos_token_id
|
||||||
maxlen = self.wrapped.max_length
|
maxlen = self.wrapped.max_length
|
||||||
used_custom_terms = []
|
used_custom_terms = []
|
||||||
|
remade_batch_tokens = []
|
||||||
|
overflowing_words = []
|
||||||
|
hijack_comments = []
|
||||||
|
hijack_fixes = []
|
||||||
|
token_count = 0
|
||||||
|
|
||||||
cache = {}
|
cache = {}
|
||||||
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
|
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
|
||||||
|
@ -353,9 +449,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
ovf = remade_tokens[maxlen - 2:]
|
ovf = remade_tokens[maxlen - 2:]
|
||||||
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||||
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||||
|
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||||
self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
token_count = len(remade_tokens)
|
||||||
|
|
||||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||||
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
||||||
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
||||||
|
@ -364,11 +459,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
||||||
|
|
||||||
remade_batch_tokens.append(remade_tokens)
|
remade_batch_tokens.append(remade_tokens)
|
||||||
self.hijack.fixes.append(fixes)
|
hijack_fixes.append(fixes)
|
||||||
batch_multipliers.append(multipliers)
|
batch_multipliers.append(multipliers)
|
||||||
|
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||||
|
|
||||||
|
def forward(self, text):
|
||||||
|
|
||||||
|
if opts.use_old_emphasis_implementation:
|
||||||
|
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
|
||||||
|
else:
|
||||||
|
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
||||||
|
|
||||||
|
|
||||||
|
self.hijack.fixes = hijack_fixes
|
||||||
|
self.hijack.comments = hijack_comments
|
||||||
|
|
||||||
if len(used_custom_terms) > 0:
|
if len(used_custom_terms) > 0:
|
||||||
self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
||||||
|
|
||||||
tokens = torch.asarray(remade_batch_tokens).to(device)
|
tokens = torch.asarray(remade_batch_tokens).to(device)
|
||||||
outputs = self.wrapped.transformer(input_ids=tokens)
|
outputs = self.wrapped.transformer(input_ids=tokens)
|
||||||
|
|
|
@ -23,6 +23,10 @@ except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def checkpoint_tiles():
|
||||||
|
return sorted([x.title for x in checkpoints_list.values()])
|
||||||
|
|
||||||
|
|
||||||
def list_models():
|
def list_models():
|
||||||
checkpoints_list.clear()
|
checkpoints_list.clear()
|
||||||
|
|
||||||
|
@ -39,13 +43,14 @@ def list_models():
|
||||||
if name.startswith("\\") or name.startswith("/"):
|
if name.startswith("\\") or name.startswith("/"):
|
||||||
name = name[1:]
|
name = name[1:]
|
||||||
|
|
||||||
return f'{name} [{h}]'
|
shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
||||||
|
|
||||||
|
return f'{name} [{h}]', shortname
|
||||||
|
|
||||||
cmd_ckpt = shared.cmd_opts.ckpt
|
cmd_ckpt = shared.cmd_opts.ckpt
|
||||||
if os.path.exists(cmd_ckpt):
|
if os.path.exists(cmd_ckpt):
|
||||||
h = model_hash(cmd_ckpt)
|
h = model_hash(cmd_ckpt)
|
||||||
title = modeltitle(cmd_ckpt, h)
|
title, model_name = modeltitle(cmd_ckpt, h)
|
||||||
model_name = title.rsplit(".",1)[0] # remove extension if present
|
|
||||||
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name)
|
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name)
|
||||||
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||||
print(f"Checkpoint in --ckpt argument not found: {cmd_ckpt}", file=sys.stderr)
|
print(f"Checkpoint in --ckpt argument not found: {cmd_ckpt}", file=sys.stderr)
|
||||||
|
@ -53,8 +58,7 @@ def list_models():
|
||||||
if os.path.exists(model_dir):
|
if os.path.exists(model_dir):
|
||||||
for filename in glob.glob(model_dir + '/**/*.ckpt', recursive=True):
|
for filename in glob.glob(model_dir + '/**/*.ckpt', recursive=True):
|
||||||
h = model_hash(filename)
|
h = model_hash(filename)
|
||||||
title = modeltitle(filename, h)
|
title, model_name = modeltitle(filename, h)
|
||||||
model_name = title.rsplit(".",1)[0] # remove extension if present
|
|
||||||
checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name)
|
checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -40,10 +40,8 @@ samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
|
||||||
|
|
||||||
sampler_extra_params = {
|
sampler_extra_params = {
|
||||||
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||||
'sample_euler_ancestral': ['eta'],
|
|
||||||
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||||
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||||
'sample_dpm_2_ancestral': ['eta'],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def setup_img2img_steps(p, steps=None):
|
def setup_img2img_steps(p, steps=None):
|
||||||
|
@ -101,6 +99,8 @@ class VanillaStableDiffusionSampler:
|
||||||
self.init_latent = None
|
self.init_latent = None
|
||||||
self.sampler_noises = None
|
self.sampler_noises = None
|
||||||
self.step = 0
|
self.step = 0
|
||||||
|
self.eta = None
|
||||||
|
self.default_eta = 0.0
|
||||||
|
|
||||||
def number_of_needed_noises(self, p):
|
def number_of_needed_noises(self, p):
|
||||||
return 0
|
return 0
|
||||||
|
@ -123,20 +123,29 @@ class VanillaStableDiffusionSampler:
|
||||||
self.step += 1
|
self.step += 1
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def initialize(self, p):
|
||||||
|
self.eta = p.eta or opts.eta_ddim
|
||||||
|
|
||||||
|
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
|
||||||
|
if hasattr(self.sampler, fieldname):
|
||||||
|
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
|
||||||
|
|
||||||
|
self.mask = p.mask if hasattr(p, 'mask') else None
|
||||||
|
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||||
|
|
||||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
|
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
|
||||||
steps, t_enc = setup_img2img_steps(p, steps)
|
steps, t_enc = setup_img2img_steps(p, steps)
|
||||||
|
|
||||||
|
self.initialize(p)
|
||||||
|
|
||||||
# existing code fails with cetain step counts, like 9
|
# existing code fails with cetain step counts, like 9
|
||||||
try:
|
try:
|
||||||
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=p.ddim_eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
||||||
except Exception:
|
except Exception:
|
||||||
self.sampler.make_schedule(ddim_num_steps=steps+1,ddim_eta=p.ddim_eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
self.sampler.make_schedule(ddim_num_steps=steps+1, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
||||||
|
|
||||||
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
||||||
|
|
||||||
self.sampler.p_sample_ddim = self.p_sample_ddim_hook
|
|
||||||
self.mask = p.mask if hasattr(p, 'mask') else None
|
|
||||||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
|
||||||
self.init_latent = x
|
self.init_latent = x
|
||||||
self.step = 0
|
self.step = 0
|
||||||
|
|
||||||
|
@ -145,11 +154,8 @@ class VanillaStableDiffusionSampler:
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
|
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
|
||||||
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
|
self.initialize(p)
|
||||||
if hasattr(self.sampler, fieldname):
|
|
||||||
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
|
|
||||||
self.mask = None
|
|
||||||
self.nmask = None
|
|
||||||
self.init_latent = None
|
self.init_latent = None
|
||||||
self.step = 0
|
self.step = 0
|
||||||
|
|
||||||
|
@ -157,9 +163,9 @@ class VanillaStableDiffusionSampler:
|
||||||
|
|
||||||
# existing code fails with cetin step counts, like 9
|
# existing code fails with cetin step counts, like 9
|
||||||
try:
|
try:
|
||||||
samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.eta)
|
samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
|
||||||
except Exception:
|
except Exception:
|
||||||
samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.eta)
|
samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
|
||||||
|
|
||||||
return samples_ddim
|
return samples_ddim
|
||||||
|
|
||||||
|
@ -237,6 +243,8 @@ class KDiffusionSampler:
|
||||||
self.sampler_noises = None
|
self.sampler_noises = None
|
||||||
self.sampler_noise_index = 0
|
self.sampler_noise_index = 0
|
||||||
self.stop_at = None
|
self.stop_at = None
|
||||||
|
self.eta = None
|
||||||
|
self.default_eta = 1.0
|
||||||
|
|
||||||
def callback_state(self, d):
|
def callback_state(self, d):
|
||||||
store_latent(d["denoised"])
|
store_latent(d["denoised"])
|
||||||
|
@ -255,22 +263,12 @@ class KDiffusionSampler:
|
||||||
self.sampler_noise_index += 1
|
self.sampler_noise_index += 1
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
|
def initialize(self, p):
|
||||||
steps, t_enc = setup_img2img_steps(p, steps)
|
|
||||||
|
|
||||||
sigmas = self.model_wrap.get_sigmas(steps)
|
|
||||||
|
|
||||||
noise = noise * sigmas[steps - t_enc - 1]
|
|
||||||
|
|
||||||
xi = x + noise
|
|
||||||
|
|
||||||
sigma_sched = sigmas[steps - t_enc - 1:]
|
|
||||||
|
|
||||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||||
self.model_wrap_cfg.init_latent = x
|
|
||||||
self.model_wrap.step = 0
|
self.model_wrap.step = 0
|
||||||
self.sampler_noise_index = 0
|
self.sampler_noise_index = 0
|
||||||
|
self.eta = p.eta or opts.eta_ancestral
|
||||||
|
|
||||||
if hasattr(k_diffusion.sampling, 'trange'):
|
if hasattr(k_diffusion.sampling, 'trange'):
|
||||||
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs)
|
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs)
|
||||||
|
@ -283,6 +281,25 @@ class KDiffusionSampler:
|
||||||
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
||||||
extra_params_kwargs[param_name] = getattr(p, param_name)
|
extra_params_kwargs[param_name] = getattr(p, param_name)
|
||||||
|
|
||||||
|
if 'eta' in inspect.signature(self.func).parameters:
|
||||||
|
extra_params_kwargs['eta'] = self.eta
|
||||||
|
|
||||||
|
return extra_params_kwargs
|
||||||
|
|
||||||
|
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
|
||||||
|
steps, t_enc = setup_img2img_steps(p, steps)
|
||||||
|
|
||||||
|
sigmas = self.model_wrap.get_sigmas(steps)
|
||||||
|
|
||||||
|
noise = noise * sigmas[steps - t_enc - 1]
|
||||||
|
xi = x + noise
|
||||||
|
|
||||||
|
extra_params_kwargs = self.initialize(p)
|
||||||
|
|
||||||
|
sigma_sched = sigmas[steps - t_enc - 1:]
|
||||||
|
|
||||||
|
self.model_wrap_cfg.init_latent = x
|
||||||
|
|
||||||
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
|
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
|
||||||
|
|
||||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
|
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
|
||||||
|
@ -291,19 +308,7 @@ class KDiffusionSampler:
|
||||||
sigmas = self.model_wrap.get_sigmas(steps)
|
sigmas = self.model_wrap.get_sigmas(steps)
|
||||||
x = x * sigmas[0]
|
x = x * sigmas[0]
|
||||||
|
|
||||||
self.model_wrap_cfg.step = 0
|
extra_params_kwargs = self.initialize(p)
|
||||||
self.sampler_noise_index = 0
|
|
||||||
|
|
||||||
if hasattr(k_diffusion.sampling, 'trange'):
|
|
||||||
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs)
|
|
||||||
|
|
||||||
if self.sampler_noises is not None:
|
|
||||||
k_diffusion.sampling.torch = TorchHijack(self)
|
|
||||||
|
|
||||||
extra_params_kwargs = {}
|
|
||||||
for param_name in self.extra_params:
|
|
||||||
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
|
||||||
extra_params_kwargs[param_name] = getattr(p, param_name)
|
|
||||||
|
|
||||||
samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
|
samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
|
||||||
|
|
||||||
|
|
|
@ -143,6 +143,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
||||||
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
|
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
|
||||||
|
|
||||||
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
|
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
|
||||||
|
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
||||||
|
@ -180,7 +181,6 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
|
||||||
"face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
|
"face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
|
||||||
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
||||||
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
|
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
|
||||||
"save_selected_only": OptionInfo(False, "When using 'Save' button, only save a single selected image"),
|
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('system', "System"), {
|
options_templates.update(options_section(('system', "System"), {
|
||||||
|
@ -190,12 +190,13 @@ options_templates.update(options_section(('system', "System"), {
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": [x.title for x in modules.sd_models.checkpoints_list.values()]}),
|
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": modules.sd_models.checkpoint_tiles()}),
|
||||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
|
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
|
||||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
||||||
"enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
"enable_emphasis": OptionInfo(True, "Eemphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||||
|
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||||
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
|
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
|
||||||
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
|
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
|
||||||
|
@ -221,8 +222,9 @@ options_templates.update(options_section(('ui', "User interface"), {
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||||
"eta": OptionInfo(0.0, "DDIM and K Ancestral eta", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform','quad']}),
|
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
|
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
||||||
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
|
|
124
modules/ui.py
124
modules/ui.py
|
@ -9,10 +9,12 @@ import random
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
import platform
|
||||||
|
import subprocess as sp
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image, PngImagePlugin
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import gradio.utils
|
import gradio.utils
|
||||||
|
@ -22,6 +24,7 @@ from modules.paths import script_path
|
||||||
from modules.shared import opts, cmd_opts
|
from modules.shared import opts, cmd_opts
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.sd_samplers import samplers, samplers_for_img2img
|
from modules.sd_samplers import samplers, samplers_for_img2img
|
||||||
|
from modules.sd_hijack import model_hijack
|
||||||
import modules.ldsr_model
|
import modules.ldsr_model
|
||||||
import modules.scripts
|
import modules.scripts
|
||||||
import modules.gfpgan_model
|
import modules.gfpgan_model
|
||||||
|
@ -61,7 +64,7 @@ random_symbol = '\U0001f3b2\ufe0f' # 🎲️
|
||||||
reuse_symbol = '\u267b\ufe0f' # ♻️
|
reuse_symbol = '\u267b\ufe0f' # ♻️
|
||||||
art_symbol = '\U0001f3a8' # 🎨
|
art_symbol = '\U0001f3a8' # 🎨
|
||||||
paste_symbol = '\u2199\ufe0f' # ↙
|
paste_symbol = '\u2199\ufe0f' # ↙
|
||||||
|
folder_symbol = '\uD83D\uDCC2'
|
||||||
|
|
||||||
def plaintext_to_html(text):
|
def plaintext_to_html(text):
|
||||||
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
|
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
|
||||||
|
@ -102,6 +105,7 @@ def save_files(js_data, images, index):
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
||||||
data = json.loads(js_data)
|
data = json.loads(js_data)
|
||||||
|
|
||||||
p = MyObject(data)
|
p = MyObject(data)
|
||||||
path = opts.outdir_save
|
path = opts.outdir_save
|
||||||
save_to_dirs = opts.save_to_dirs
|
save_to_dirs = opts.save_to_dirs
|
||||||
|
@ -112,9 +116,13 @@ def save_files(js_data, images, index):
|
||||||
|
|
||||||
os.makedirs(path, exist_ok=True)
|
os.makedirs(path, exist_ok=True)
|
||||||
|
|
||||||
if index > -1 and opts.save_selected_only and (index > 0 or not opts.return_grid): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
|
|
||||||
|
if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
|
||||||
|
|
||||||
images = [images[index]]
|
images = [images[index]]
|
||||||
data["seed"] += (index - 1 if opts.return_grid else index)
|
infotexts = [data["infotexts"][index]]
|
||||||
|
else:
|
||||||
|
infotexts = data["infotexts"]
|
||||||
|
|
||||||
with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
|
with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
|
||||||
at_start = file.tell() == 0
|
at_start = file.tell() == 0
|
||||||
|
@ -137,8 +145,11 @@ def save_files(js_data, images, index):
|
||||||
if filedata.startswith("data:image/png;base64,"):
|
if filedata.startswith("data:image/png;base64,"):
|
||||||
filedata = filedata[len("data:image/png;base64,"):]
|
filedata = filedata[len("data:image/png;base64,"):]
|
||||||
|
|
||||||
with open(filepath, "wb") as imgfile:
|
pnginfo = PngImagePlugin.PngInfo()
|
||||||
imgfile.write(base64.decodebytes(filedata.encode('utf-8')))
|
pnginfo.add_text('parameters', infotexts[i])
|
||||||
|
|
||||||
|
image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8'))))
|
||||||
|
image.save(filepath, quality=opts.jpeg_quality, pnginfo=pnginfo)
|
||||||
|
|
||||||
filenames.append(filename)
|
filenames.append(filename)
|
||||||
|
|
||||||
|
@ -350,6 +361,10 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
|
||||||
outputs=[seed, dummy_component]
|
outputs=[seed, dummy_component]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def update_token_counter(text):
|
||||||
|
tokens, token_count, max_length = model_hijack.tokenize(text)
|
||||||
|
style_class = ' class="red"' if (token_count > max_length) else ""
|
||||||
|
return f"<span {style_class}>{token_count}/{max_length}</span>"
|
||||||
|
|
||||||
def create_toprow(is_img2img):
|
def create_toprow(is_img2img):
|
||||||
id_part = "img2img" if is_img2img else "txt2img"
|
id_part = "img2img" if is_img2img else "txt2img"
|
||||||
|
@ -359,11 +374,14 @@ def create_toprow(is_img2img):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=80):
|
with gr.Column(scale=80):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
prompt = gr.Textbox(label="Prompt", elem_id="prompt", show_label=False, placeholder="Prompt", lines=2)
|
prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2)
|
||||||
|
|
||||||
with gr.Column(scale=1, elem_id="roll_col"):
|
with gr.Column(scale=1, elem_id="roll_col"):
|
||||||
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
||||||
paste = gr.Button(value=paste_symbol, elem_id="paste")
|
paste = gr.Button(value=paste_symbol, elem_id="paste")
|
||||||
|
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
|
||||||
|
hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
||||||
|
hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter])
|
||||||
|
|
||||||
with gr.Column(scale=10, elem_id="style_pos_col"):
|
with gr.Column(scale=10, elem_id="style_pos_col"):
|
||||||
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
|
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
|
||||||
|
@ -470,6 +488,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
send_to_img2img = gr.Button('Send to img2img')
|
send_to_img2img = gr.Button('Send to img2img')
|
||||||
send_to_inpaint = gr.Button('Send to inpaint')
|
send_to_inpaint = gr.Button('Send to inpaint')
|
||||||
send_to_extras = gr.Button('Send to extras')
|
send_to_extras = gr.Button('Send to extras')
|
||||||
|
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
|
||||||
|
open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id)
|
||||||
|
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
html_info = gr.HTML()
|
html_info = gr.HTML()
|
||||||
|
@ -646,6 +666,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
img2img_send_to_img2img = gr.Button('Send to img2img')
|
img2img_send_to_img2img = gr.Button('Send to img2img')
|
||||||
img2img_send_to_inpaint = gr.Button('Send to inpaint')
|
img2img_send_to_inpaint = gr.Button('Send to inpaint')
|
||||||
img2img_send_to_extras = gr.Button('Send to extras')
|
img2img_send_to_extras = gr.Button('Send to extras')
|
||||||
|
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
|
||||||
|
open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id)
|
||||||
|
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
html_info = gr.HTML()
|
html_info = gr.HTML()
|
||||||
|
@ -818,6 +840,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
html_info = gr.HTML()
|
html_info = gr.HTML()
|
||||||
extras_send_to_img2img = gr.Button('Send to img2img')
|
extras_send_to_img2img = gr.Button('Send to img2img')
|
||||||
extras_send_to_inpaint = gr.Button('Send to inpaint')
|
extras_send_to_inpaint = gr.Button('Send to inpaint')
|
||||||
|
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else ''
|
||||||
|
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
|
||||||
|
|
||||||
submit.click(
|
submit.click(
|
||||||
fn=run_extras,
|
fn=run_extras,
|
||||||
|
@ -878,32 +902,20 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
with gr.Blocks() as modelmerger_interface:
|
with gr.Blocks() as modelmerger_interface:
|
||||||
with gr.Row().style(equal_height=False):
|
with gr.Row().style(equal_height=False):
|
||||||
with gr.Column(variant='panel'):
|
with gr.Column(variant='panel'):
|
||||||
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>/models</b> directory.</p>")
|
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
ckpt_name_list = sorted([x.model_name for x in modules.sd_models.checkpoints_list.values()])
|
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
|
||||||
primary_model_name = gr.Dropdown(ckpt_name_list, elem_id="modelmerger_primary_model_name", label="Primary Model Name")
|
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
|
||||||
secondary_model_name = gr.Dropdown(ckpt_name_list, elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
|
custom_name = gr.Textbox(label="Custom Name (Optional)")
|
||||||
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3)
|
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3)
|
||||||
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid"], value="Weighted Sum", label="Interpolation Method")
|
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
|
||||||
submit = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
|
save_as_half = gr.Checkbox(value=False, label="Safe as float16")
|
||||||
|
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
|
||||||
|
|
||||||
with gr.Column(variant='panel'):
|
with gr.Column(variant='panel'):
|
||||||
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
|
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
|
||||||
|
|
||||||
submit.click(
|
|
||||||
fn=run_modelmerger,
|
|
||||||
inputs=[
|
|
||||||
primary_model_name,
|
|
||||||
secondary_model_name,
|
|
||||||
interp_method,
|
|
||||||
interp_amount
|
|
||||||
],
|
|
||||||
outputs=[
|
|
||||||
submit_result,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_setting_component(key):
|
def create_setting_component(key):
|
||||||
def fun():
|
def fun():
|
||||||
return opts.data[key] if key in opts.data else opts.data_labels[key].default
|
return opts.data[key] if key in opts.data else opts.data_labels[key].default
|
||||||
|
@ -927,6 +939,17 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
return comp(label=info.label, value=fun, **(args or {}))
|
return comp(label=info.label, value=fun, **(args or {}))
|
||||||
|
|
||||||
components = []
|
components = []
|
||||||
|
component_dict = {}
|
||||||
|
|
||||||
|
def open_folder(f):
|
||||||
|
if not shared.cmd_opts.hide_ui_dir_config:
|
||||||
|
path = os.path.normpath(f)
|
||||||
|
if platform.system() == "Windows":
|
||||||
|
os.startfile(path)
|
||||||
|
elif platform.system() == "Darwin":
|
||||||
|
sp.Popen(["open", path])
|
||||||
|
else:
|
||||||
|
sp.Popen(["xdg-open", path])
|
||||||
|
|
||||||
def run_settings(*args):
|
def run_settings(*args):
|
||||||
changed = 0
|
changed = 0
|
||||||
|
@ -982,7 +1005,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
|
|
||||||
gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))
|
gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))
|
||||||
|
|
||||||
components.append(create_setting_component(k))
|
component = create_setting_component(k)
|
||||||
|
component_dict[k] = component
|
||||||
|
components.append(component)
|
||||||
items_displayed += 1
|
items_displayed += 1
|
||||||
|
|
||||||
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
|
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
|
||||||
|
@ -1033,6 +1058,33 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
outputs=[result, text_settings],
|
outputs=[result, text_settings],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def modelmerger(*args):
|
||||||
|
try:
|
||||||
|
results = run_modelmerger(*args)
|
||||||
|
except Exception as e:
|
||||||
|
print("Error loading/saving model file:", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
modules.sd_models.list_models() #To remove the potentially missing models from the list
|
||||||
|
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
|
||||||
|
return results
|
||||||
|
|
||||||
|
modelmerger_merge.click(
|
||||||
|
fn=modelmerger,
|
||||||
|
inputs=[
|
||||||
|
primary_model_name,
|
||||||
|
secondary_model_name,
|
||||||
|
interp_method,
|
||||||
|
interp_amount,
|
||||||
|
save_as_half,
|
||||||
|
custom_name,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
submit_result,
|
||||||
|
primary_model_name,
|
||||||
|
secondary_model_name,
|
||||||
|
component_dict['sd_model_checkpoint'],
|
||||||
|
]
|
||||||
|
)
|
||||||
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Seed', 'Size-1', 'Size-2']
|
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Seed', 'Size-1', 'Size-2']
|
||||||
txt2img_fields = [field for field,name in txt2img_paste_fields if name in paste_field_names]
|
txt2img_fields = [field for field,name in txt2img_paste_fields if name in paste_field_names]
|
||||||
img2img_fields = [field for field,name in img2img_paste_fields if name in paste_field_names]
|
img2img_fields = [field for field,name in img2img_paste_fields if name in paste_field_names]
|
||||||
|
@ -1071,6 +1123,24 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
outputs=[extras_image],
|
outputs=[extras_image],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
open_txt2img_folder.click(
|
||||||
|
fn=lambda: open_folder(opts.outdir_samples or opts.outdir_txt2img_samples),
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
open_img2img_folder.click(
|
||||||
|
fn=lambda: open_folder(opts.outdir_samples or opts.outdir_img2img_samples),
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
open_extras_folder.click(
|
||||||
|
fn=lambda: open_folder(opts.outdir_samples or opts.outdir_extras_samples),
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
)
|
||||||
|
|
||||||
img2img_send_to_extras.click(
|
img2img_send_to_extras.click(
|
||||||
fn=lambda x: image_from_url_text(x),
|
fn=lambda x: image_from_url_text(x),
|
||||||
_js="extract_image_from_gallery_extras",
|
_js="extract_image_from_gallery_extras",
|
||||||
|
|
|
@ -6,7 +6,6 @@ font-roboto
|
||||||
gfpgan
|
gfpgan
|
||||||
gradio
|
gradio
|
||||||
invisible-watermark
|
invisible-watermark
|
||||||
git+https://github.com/crowsonkb/k-diffusion.git
|
|
||||||
numpy
|
numpy
|
||||||
omegaconf
|
omegaconf
|
||||||
piexif
|
piexif
|
||||||
|
@ -16,5 +15,12 @@ realesrgan
|
||||||
scikit-image>=0.19
|
scikit-image>=0.19
|
||||||
git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379
|
git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379
|
||||||
timm==0.4.12
|
timm==0.4.12
|
||||||
transformers
|
transformers==4.19.2
|
||||||
torch
|
torch
|
||||||
|
einops
|
||||||
|
jsonmerge
|
||||||
|
clean-fid
|
||||||
|
git+https://github.com/openai/CLIP@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
|
||||||
|
resize-right
|
||||||
|
torchdiffeq
|
||||||
|
kornia
|
||||||
|
|
|
@ -15,3 +15,10 @@ font-roboto
|
||||||
timm==0.6.7
|
timm==0.6.7
|
||||||
fairscale==0.4.9
|
fairscale==0.4.9
|
||||||
piexif==1.1.3
|
piexif==1.1.3
|
||||||
|
einops==0.4.1
|
||||||
|
jsonmerge==1.8.0
|
||||||
|
clean-fid==0.1.29
|
||||||
|
git+https://github.com/openai/CLIP@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
|
||||||
|
resize-right==0.0.2
|
||||||
|
torchdiffeq==0.2.3
|
||||||
|
kornia==0.6.7
|
||||||
|
|
|
@ -91,8 +91,8 @@ axis_options = [
|
||||||
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label),
|
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label),
|
||||||
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
|
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
|
||||||
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
|
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
|
||||||
AxisOption("DDIM Eta", float, apply_field("ddim_eta"), format_value_add_label),
|
AxisOption("Eta", float, apply_field("eta"), format_value_add_label),
|
||||||
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label),# as it is now all AxisOptionImg2Img items must go after AxisOption ones
|
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
16
style.css
16
style.css
|
@ -1,5 +1,11 @@
|
||||||
.output-html p {margin: 0 0.5em;}
|
.output-html p {margin: 0 0.5em;}
|
||||||
|
|
||||||
|
.row > *,
|
||||||
|
.row > .gr-form > * {
|
||||||
|
min-width: min(120px, 100%);
|
||||||
|
flex: 1 1 0%;
|
||||||
|
}
|
||||||
|
|
||||||
.performance {
|
.performance {
|
||||||
font-size: 0.85em;
|
font-size: 0.85em;
|
||||||
color: #444;
|
color: #444;
|
||||||
|
@ -43,13 +49,17 @@
|
||||||
margin-right: auto;
|
margin-right: auto;
|
||||||
}
|
}
|
||||||
|
|
||||||
#random_seed, #random_subseed, #reuse_seed, #reuse_subseed{
|
#random_seed, #random_subseed, #reuse_seed, #reuse_subseed, #open_folder{
|
||||||
min-width: auto;
|
min-width: auto;
|
||||||
flex-grow: 0;
|
flex-grow: 0;
|
||||||
padding-left: 0.25em;
|
padding-left: 0.25em;
|
||||||
padding-right: 0.25em;
|
padding-right: 0.25em;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#hidden_element{
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
#seed_row, #subseed_row{
|
#seed_row, #subseed_row{
|
||||||
gap: 0.5rem;
|
gap: 0.5rem;
|
||||||
}
|
}
|
||||||
|
@ -389,3 +399,7 @@ input[type="range"]{
|
||||||
border-radius: 8px;
|
border-radius: 8px;
|
||||||
display: none;
|
display: none;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.red {
|
||||||
|
color: red;
|
||||||
|
}
|
||||||
|
|
5
webui.py
5
webui.py
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
|
from modules import devices
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
|
|
||||||
import signal
|
import signal
|
||||||
|
@ -47,6 +48,8 @@ def wrap_queued_call(func):
|
||||||
|
|
||||||
def wrap_gradio_gpu_call(func):
|
def wrap_gradio_gpu_call(func):
|
||||||
def f(*args, **kwargs):
|
def f(*args, **kwargs):
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
shared.state.sampling_step = 0
|
shared.state.sampling_step = 0
|
||||||
shared.state.job_count = -1
|
shared.state.job_count = -1
|
||||||
shared.state.job_no = 0
|
shared.state.job_no = 0
|
||||||
|
@ -62,6 +65,8 @@ def wrap_gradio_gpu_call(func):
|
||||||
shared.state.job = ""
|
shared.state.job = ""
|
||||||
shared.state.job_count = 0
|
shared.state.job_count = 0
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
return modules.ui.wrap_gradio_call(f)
|
return modules.ui.wrap_gradio_call(f)
|
||||||
|
|
Loading…
Reference in a new issue