fix to tokens lenght, addend embs generator, add new features to edit the embedding before the generation using text

This commit is contained in:
MalumaDev 2022-10-15 15:59:37 +02:00
parent bb57f30c2d
commit 37d7ffb415
6 changed files with 302 additions and 96 deletions

78
modules/aesthetic_clip.py Normal file
View file

@ -0,0 +1,78 @@
import itertools
import os
from pathlib import Path
import html
import gc
import gradio as gr
import torch
from PIL import Image
from modules import shared
from modules.shared import device, aesthetic_embeddings
from transformers import CLIPModel, CLIPProcessor
from tqdm.auto import tqdm
def get_all_images_in_folder(folder):
return [os.path.join(folder, f) for f in os.listdir(folder) if
os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)]
def check_is_valid_image_file(filename):
return filename.lower().endswith(('.png', '.jpg', '.jpeg'))
def batched(dataset, total, n=1):
for ndx in range(0, total, n):
yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))]
def iter_to_batched(iterable, n=1):
it = iter(iterable)
while True:
chunk = tuple(itertools.islice(it, n))
if not chunk:
return
yield chunk
def generate_imgs_embd(name, folder, batch_size):
# clipModel = CLIPModel.from_pretrained(
# shared.sd_model.cond_stage_model.clipModel.name_or_path
# )
model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path).to(device)
processor = CLIPProcessor.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path)
with torch.no_grad():
embs = []
for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size),
desc=f"Generating embeddings for {name}"):
if shared.state.interrupted:
break
inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device)
outputs = model.get_image_features(**inputs).cpu()
embs.append(torch.clone(outputs))
inputs.to("cpu")
del inputs, outputs
embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)
# The generated embedding will be located here
path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
torch.save(embs, path)
model = model.cpu()
del model
del processor
del embs
gc.collect()
torch.cuda.empty_cache()
res = f"""
Done generating embedding for {name}!
Hypernetwork saved to {html.escape(path)}
"""
shared.update_aesthetic_embeddings()
return gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding",
value=sorted(aesthetic_embeddings.keys())[0] if len(
aesthetic_embeddings) > 0 else None), res, ""

View file

@ -20,7 +20,6 @@ import modules.images as images
import modules.styles
import logging
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
opt_f = 8
@ -52,8 +51,13 @@ def get_correct_sampler(p):
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
return sd_samplers.samplers_for_img2img
class StableDiffusionProcessing:
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1,
subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True,
sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512,
restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False,
extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
@ -104,7 +108,8 @@ class StableDiffusionProcessing:
class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None,
all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
self.images = images_list
self.prompt = p.prompt
self.negative_prompt = p.negative_prompt
@ -141,7 +146,8 @@ class Processed:
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
self.subseed = int(
self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
self.all_prompts = all_prompts or [self.prompt]
self.all_seeds = all_seeds or [self.seed]
@ -181,39 +187,43 @@ class Processed:
return json.dumps(obj)
def infotext(self, p: StableDiffusionProcessing, index):
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
def infotext(self, p: StableDiffusionProcessing, index):
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[],
position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
low_norm = low/torch.norm(low, dim=1, keepdim=True)
high_norm = high/torch.norm(high, dim=1, keepdim=True)
dot = (low_norm*high_norm).sum(1)
low_norm = low / torch.norm(low, dim=1, keepdim=True)
high_norm = high / torch.norm(high, dim=1, keepdim=True)
dot = (low_norm * high_norm).sum(1)
if dot.mean() > 0.9995:
return low * val + high * (1 - val)
omega = torch.acos(dot)
so = torch.sin(omega)
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
return res
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0,
p=None):
xs = []
# if we have multiple seeds, this means we are working with batch size>1; this then
# enables the generation of additional tensors with noise that the sampler will use during its processing.
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
# produce the same images as with two batches [100], [101].
if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
if p is not None and p.sampler is not None and (
len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
else:
sampler_noises = None
for i, seed in enumerate(seeds):
noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (
shape[0], seed_resize_from_h // 8, seed_resize_from_w // 8)
subnoise = None
if subseeds is not None:
@ -241,7 +251,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
dx = max(-dx, 0)
dy = max(-dy, 0)
x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w]
noise = x
if sampler_noises is not None:
@ -293,14 +303,20 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Seed": all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')),
"Model hash": getattr(p, 'sd_model_hash',
None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (
None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(
',', '').replace(':', '')),
"Hypernet": (
None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(
':', '')),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Seed resize from": (
None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
"Clip skip": None if clip_skip <= 1 else clip_skip,
@ -309,7 +325,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params.update(p.extra_generation_params)
generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
generation_params_text = ", ".join(
[k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
@ -317,7 +334,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0,
aesthetic_imgs=None,aesthetic_slerp=False) -> Processed:
aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="",
aesthetic_slerp_angle=0.15,
aesthetic_text_negative=False) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
aesthetic_lr = float(aesthetic_lr)
@ -385,7 +404,7 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
for n in range(p.n_iter):
if state.skipped:
state.skipped = False
if state.interrupted:
break
@ -396,16 +415,19 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
if (len(prompts) == 0):
break
#uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
#c = p.sd_model.get_learned_conditioning(prompts)
# uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
# c = p.sd_model.get_learned_conditioning(prompts)
with devices.autocast():
if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"):
shared.sd_model.cond_stage_model.set_aesthetic_params(0, 0, 0)
shared.sd_model.cond_stage_model.set_aesthetic_params()
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt],
p.steps)
if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"):
shared.sd_model.cond_stage_model.set_aesthetic_params(aesthetic_lr, aesthetic_weight,
aesthetic_steps, aesthetic_imgs,aesthetic_slerp)
aesthetic_steps, aesthetic_imgs,
aesthetic_slerp, aesthetic_imgs_text,
aesthetic_slerp_angle,
aesthetic_text_negative)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
if len(model_hijack.comments) > 0:
@ -413,13 +435,13 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
comments[comment] = 1
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
shared.state.job = f"Batch {n + 1} out of {p.n_iter}"
with devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds,
subseed_strength=p.subseed_strength)
if state.interrupted or state.skipped:
# if we are interrupted, sample returns just noise
# use the image collected previously in sampler loop
samples_ddim = shared.state.current_latent
@ -445,7 +467,9 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
if p.restore_faces:
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i],
opts.samples_format, info=infotext(n, i), p=p,
suffix="-before-face-restoration")
devices.torch_gc()
@ -456,7 +480,8 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format,
info=infotext(n, i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
if p.overlay_images is not None and i < len(p.overlay_images):
@ -474,7 +499,8 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
image = image.convert('RGB')
if opts.samples_save and not p.do_not_save_samples:
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format,
info=infotext(n, i), p=p)
text = infotext(n, i)
infotexts.append(text)
@ -482,7 +508,7 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
image.info["parameters"] = text
output_images.append(image)
del x_samples_ddim
del x_samples_ddim
devices.torch_gc()
@ -504,10 +530,13 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
index_of_first_image = 1
if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format,
info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc()
return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]),
subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds,
index_of_first_image=index_of_first_image, infotexts=infotexts)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
@ -543,25 +572,34 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds,
subseeds=subseeds, subseed_strength=self.subseed_strength,
seed_resize_from_h=self.seed_resize_from_h,
seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds,
subseeds=subseeds, subseed_strength=self.subseed_strength,
seed_resize_from_h=self.seed_resize_from_h,
seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f
truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f
samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2]
samples = samples[:, :, truncate_y // 2:samples.shape[2] - truncate_y // 2,
truncate_x // 2:samples.shape[3] - truncate_x // 2]
if self.scale_latent:
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f),
mode="bilinear")
else:
decoded_samples = decode_first_stage(self.sd_model, samples)
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width),
mode="bilinear")
else:
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
@ -585,13 +623,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds,
subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h,
seed_resize_from_w=self.seed_resize_from_w, p=self)
# GC now before running the next img2img to prevent running out of memory
x = None
devices.torch_gc()
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning,
steps=self.steps)
return samples
@ -599,7 +640,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs):
def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4,
inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0,
**kwargs):
super().__init__(**kwargs)
self.init_images = init_images
@ -607,7 +650,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.denoising_strength: float = denoising_strength
self.init_latent = None
self.image_mask = mask
#self.image_unblurred_mask = None
# self.image_unblurred_mask = None
self.latent_mask = None
self.mask_for_overlay = None
self.mask_blur = mask_blur
@ -619,7 +662,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.nmask = None
def init(self, all_prompts, all_seeds, all_subseeds):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index,
self.sd_model)
crop_region = None
if self.image_mask is not None:
@ -628,7 +672,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.inpainting_mask_invert:
self.image_mask = ImageOps.invert(self.image_mask)
#self.image_unblurred_mask = self.image_mask
# self.image_unblurred_mask = self.image_mask
if self.mask_blur > 0:
self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
@ -642,7 +686,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
mask = mask.crop(crop_region)
self.image_mask = images.resize_image(2, mask, self.width, self.height)
self.paste_to = (x1, y1, x2-x1, y2-y1)
self.paste_to = (x1, y1, x2 - x1, y2 - y1)
else:
self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
np_mask = np.array(self.image_mask)
@ -665,7 +709,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.image_mask is not None:
image_masked = Image.new('RGBa', (image.width, image.height))
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
image_masked.paste(image.convert("RGBA").convert("RGBa"),
mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
self.overlay_images.append(image_masked.convert('RGBA'))
@ -714,12 +759,17 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
# this needs to be fixed to be done in sample() using actual seeds for batches
if self.inpainting_fill == 2:
self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:],
all_seeds[
0:self.init_latent.shape[
0]]) * self.nmask
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds,
subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h,
seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)

View file

@ -14,7 +14,8 @@ from modules.sd_hijack_optimizations import invokeAI_mps_available
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
from transformers import CLIPVisionModel, CLIPModel
from tqdm import trange
from transformers import CLIPVisionModel, CLIPModel, CLIPTokenizer
import torch.optim as optim
import copy
@ -22,21 +23,25 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
def apply_optimizations():
undo_optimizations()
ldm.modules.diffusionmodules.model.nonlinearity = silu
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)):
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (
6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)):
print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
elif cmd_opts.opt_split_attention_v1:
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
elif not cmd_opts.disable_opt_split_attention and (
cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
if not invokeAI_mps_available and shared.device.type == 'mps':
print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
print(
"The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
else:
@ -112,14 +117,16 @@ class StableDiffusionModelHijack:
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
def slerp(low, high, val):
low_norm = low/torch.norm(low, dim=1, keepdim=True)
high_norm = high/torch.norm(high, dim=1, keepdim=True)
omega = torch.acos((low_norm*high_norm).sum(1))
low_norm = low / torch.norm(low, dim=1, keepdim=True)
high_norm = high / torch.norm(high, dim=1, keepdim=True)
omega = torch.acos((low_norm * high_norm).sum(1))
so = torch.sin(omega)
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
return res
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
@ -128,6 +135,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.wrapped.transformer.name_or_path
)
del self.clipModel.vision_model
self.tokenizer = CLIPTokenizer.from_pretrained(self.wrapped.transformer.name_or_path)
self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer
# self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval()
@ -139,7 +147,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if
'(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
for c in text:
@ -155,8 +164,13 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0:
self.token_mults[ident] = mult
def set_aesthetic_params(self, aesthetic_lr, aesthetic_weight, aesthetic_steps, image_embs_name=None,
aesthetic_slerp=True):
def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
aesthetic_slerp=True, aesthetic_imgs_text="",
aesthetic_slerp_angle=0.15,
aesthetic_text_negative=False):
self.aesthetic_imgs_text = aesthetic_imgs_text
self.aesthetic_slerp_angle = aesthetic_slerp_angle
self.aesthetic_text_negative = aesthetic_text_negative
self.slerp = aesthetic_slerp
self.aesthetic_lr = aesthetic_lr
self.aesthetic_weight = aesthetic_weight
@ -180,7 +194,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
else:
parsed = [[line, 1.0]]
tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)[
"input_ids"]
fixes = []
remade_tokens = []
@ -196,18 +211,20 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if token == self.comma_token:
last_comma = len(remade_tokens)
elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens),
1) % 75 == 0 and last_comma != -1 and len(
remade_tokens) - last_comma <= opts.comma_padding_backtrack:
last_comma += 1
reloc_tokens = remade_tokens[last_comma:]
reloc_mults = multipliers[last_comma:]
remade_tokens = remade_tokens[:last_comma]
length = len(remade_tokens)
rem = int(math.ceil(length / 75)) * 75 - length
remade_tokens += [id_end] * rem + reloc_tokens
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
@ -248,7 +265,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if line in cache:
remade_tokens, fixes, multipliers = cache[line]
else:
remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms,
hijack_comments)
token_count = max(current_token_count, token_count)
cache[line] = (remade_tokens, fixes, multipliers)
@ -259,7 +277,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
@ -289,7 +306,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens,
i)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None:
@ -312,11 +330,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
ovf = remade_tokens[maxlen - 2:]
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
hijack_comments.append(
f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
@ -326,23 +345,26 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def forward(self, text):
use_old = opts.use_old_emphasis_implementation
if use_old:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(
text)
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(
text)
self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
self.hijack.comments.append(
"Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
if use_old:
self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers)
z = None
i = 0
while max(map(len, remade_batch_tokens)) != 0:
@ -356,7 +378,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if fix[0] == i:
fixes.append(fix[1])
self.hijack.fixes.append(fixes)
tokens = []
multipliers = []
for j in range(len(remade_batch_tokens)):
@ -378,19 +400,30 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
remade_batch_tokens]
tokens = torch.asarray(remade_batch_tokens).to(device)
model = copy.deepcopy(self.clipModel).to(device)
model.requires_grad_(True)
if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
text_embs_2 = model.get_text_features(
**self.tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device))
if self.aesthetic_text_negative:
text_embs_2 = self.image_embs - text_embs_2
text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True)
img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle)
else:
img_embs = self.image_embs
with torch.enable_grad():
model = copy.deepcopy(self.clipModel).to(device)
model.requires_grad_(True)
# We optimize the model to maximize the similarity
optimizer = optim.Adam(
model.text_model.parameters(), lr=self.aesthetic_lr
)
for i in range(self.aesthetic_steps):
for i in trange(self.aesthetic_steps, desc="Aesthetic optimization"):
text_embs = model.get_text_features(input_ids=tokens)
text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
sim = text_embs @ self.image_embs.T
sim = text_embs @ img_embs.T
loss = -sim
optimizer.zero_grad()
loss.mean().backward()
@ -405,6 +438,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
model.cpu()
del model
zn = torch.concat([zn for i in range(z.shape[1] // 77)], 1)
if self.slerp:
z = slerp(z, zn, self.aesthetic_weight)
else:
@ -413,15 +447,16 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1
return z
def process_tokens(self, remade_batch_tokens, batch_multipliers):
if not opts.use_old_emphasis_implementation:
remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
remade_batch_tokens = [
[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in
remade_batch_tokens]
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
tokens = torch.asarray(remade_batch_tokens).to(device)
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
@ -461,8 +496,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes:
emb = embedding.vec
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
vecs.append(tensor)

View file

@ -95,6 +95,10 @@ loaded_hypernetwork = None
aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
def update_aesthetic_embeddings():
global aesthetic_embeddings
aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
def reload_hypernetworks():
global hypernetworks

View file

@ -13,7 +13,11 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
aesthetic_lr=0,
aesthetic_weight=0, aesthetic_steps=0,
aesthetic_imgs=None,
aesthetic_slerp=False, *args):
aesthetic_slerp=False,
aesthetic_imgs_text="",
aesthetic_slerp_angle=0.15,
aesthetic_text_negative=False,
*args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@ -47,7 +51,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
processed = modules.scripts.scripts_txt2img.run(p, *args)
if processed is None:
processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp)
processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp,aesthetic_imgs_text,
aesthetic_slerp_angle,
aesthetic_text_negative)
shared.total_tqdm.clear()

View file

@ -41,6 +41,7 @@ from modules import prompt_parser
from modules.images import save_image
import modules.textual_inversion.ui
import modules.hypernetworks.ui
import modules.aesthetic_clip
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
@ -449,7 +450,7 @@ def create_toprow(is_img2img):
with gr.Row():
negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
with gr.Column(scale=1, elem_id="roll_col"):
sh = gr.Button(elem_id="sh", visible=True)
sh = gr.Button(elem_id="sh", visible=True)
with gr.Column(scale=1, elem_id="style_neg_col"):
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
@ -536,9 +537,13 @@ def create_ui(wrap_gradio_gpu_call):
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
with gr.Group():
aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.7)
aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=50)
aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.0001")
aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9)
aesthetic_steps = gr.Slider(minimum=0, maximum=256, step=1, label="Aesthetic steps", value=5)
with gr.Row():
aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="")
aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1)
aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None)
aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
@ -617,7 +622,10 @@ def create_ui(wrap_gradio_gpu_call):
aesthetic_weight,
aesthetic_steps,
aesthetic_imgs,
aesthetic_slerp
aesthetic_slerp,
aesthetic_imgs_text,
aesthetic_slerp_angle,
aesthetic_text_negative
] + custom_inputs,
outputs=[
txt2img_gallery,
@ -721,7 +729,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row():
inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False)
inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32)
inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=1024, step=4, value=32)
with gr.TabItem('Batch img2img', id='batch'):
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
@ -1071,6 +1079,17 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column():
create_embedding = gr.Button(value="Create embedding", variant='primary')
with gr.Tab(label="Create images embedding"):
new_embedding_name_ae = gr.Textbox(label="Name")
process_src_ae = gr.Textbox(label='Source directory')
batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256)
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
with gr.Column():
create_embedding_ae = gr.Button(value="Create images embedding", variant='primary')
with gr.Tab(label="Create hypernetwork"):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
@ -1139,7 +1158,7 @@ def create_ui(wrap_gradio_gpu_call):
fn=modules.textual_inversion.ui.create_embedding,
inputs=[
new_embedding_name,
initialization_text,
process_src,
nvpt,
],
outputs=[
@ -1149,6 +1168,20 @@ def create_ui(wrap_gradio_gpu_call):
]
)
create_embedding_ae.click(
fn=modules.aesthetic_clip.generate_imgs_embd,
inputs=[
new_embedding_name_ae,
process_src_ae,
batch_ae
],
outputs=[
aesthetic_imgs,
ti_output,
ti_outcome,
]
)
create_hypernetwork.click(
fn=modules.hypernetworks.ui.create_hypernetwork,
inputs=[