fix token length handling, add an image-embeddings generator, add new options to edit the aesthetic embedding with text before generation

MalumaDev 2022-10-15 15:59:37 +02:00
parent bb57f30c2d
commit 37d7ffb415
6 changed files with 302 additions and 96 deletions

modules/aesthetic_clip.py (new file)

@@ -0,0 +1,78 @@
import itertools
import os
from pathlib import Path
import html
import gc

import gradio as gr
import torch
from PIL import Image
from modules import shared
from modules.shared import device, aesthetic_embeddings
from transformers import CLIPModel, CLIPProcessor

from tqdm.auto import tqdm


def get_all_images_in_folder(folder):
    return [os.path.join(folder, f) for f in os.listdir(folder) if
            os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)]


def check_is_valid_image_file(filename):
    return filename.lower().endswith(('.png', '.jpg', '.jpeg'))


def batched(dataset, total, n=1):
    for ndx in range(0, total, n):
        yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))]


def iter_to_batched(iterable, n=1):
    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, n))
        if not chunk:
            return
        yield chunk


def generate_imgs_embd(name, folder, batch_size):
    # clipModel = CLIPModel.from_pretrained(
    #     shared.sd_model.cond_stage_model.clipModel.name_or_path
    # )
    model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path).to(device)
    processor = CLIPProcessor.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path)

    with torch.no_grad():
        embs = []
        for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size),
                          desc=f"Generating embeddings for {name}"):
            if shared.state.interrupted:
                break
            inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device)
            outputs = model.get_image_features(**inputs).cpu()
            embs.append(torch.clone(outputs))
            inputs.to("cpu")
            del inputs, outputs

        embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)

        # The generated embedding will be located here
        path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
        torch.save(embs, path)

        model = model.cpu()
        del model
        del processor
        del embs
        gc.collect()
        torch.cuda.empty_cache()
        res = f"""
        Done generating embedding for {name}!
        Hypernetwork saved to {html.escape(path)}
        """
        shared.update_aesthetic_embeddings()
        return gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding",
                           value=sorted(aesthetic_embeddings.keys())[0] if len(
                               aesthetic_embeddings) > 0 else None), res, ""
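For quick reference, the batching helper above can be exercised on its own outside the WebUI. This is a standalone copy with made-up file names, not part of the commit:

import itertools

def iter_to_batched(iterable, n=1):
    # Pull up to n items at a time until the source iterable runs dry.
    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, n))
        if not chunk:
            return
        yield chunk

paths = [f"img_{i}.png" for i in range(7)]  # made-up file names
print(list(iter_to_batched(paths, 3)))
# [('img_0.png', 'img_1.png', 'img_2.png'), ('img_3.png', 'img_4.png', 'img_5.png'), ('img_6.png',)]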

modules/processing.py

@ -20,7 +20,6 @@ import modules.images as images
import modules.styles import modules.styles
import logging import logging
# some of those options should not be changed at all because they would break the model, so I removed them from options. # some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4 opt_C = 4
opt_f = 8 opt_f = 8
@ -52,8 +51,13 @@ def get_correct_sampler(p):
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img): elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
return sd_samplers.samplers_for_img2img return sd_samplers.samplers_for_img2img
class StableDiffusionProcessing: class StableDiffusionProcessing:
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None): def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1,
subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True,
sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512,
restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False,
extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
self.sd_model = sd_model self.sd_model = sd_model
self.outpath_samples: str = outpath_samples self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids self.outpath_grids: str = outpath_grids
@ -104,7 +108,8 @@ class StableDiffusionProcessing:
class Processed: class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None): def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None,
all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
self.images = images_list self.images = images_list
self.prompt = p.prompt self.prompt = p.prompt
self.negative_prompt = p.negative_prompt self.negative_prompt = p.negative_prompt
@ -141,7 +146,8 @@ class Processed:
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0] self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0] self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 self.subseed = int(
self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
self.all_prompts = all_prompts or [self.prompt] self.all_prompts = all_prompts or [self.prompt]
self.all_seeds = all_seeds or [self.seed] self.all_seeds = all_seeds or [self.seed]
@ -181,39 +187,43 @@ class Processed:
return json.dumps(obj) return json.dumps(obj)
def infotext(self, p: StableDiffusionProcessing, index): def infotext(self, p: StableDiffusionProcessing, index):
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size) return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[],
position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3 # from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high): def slerp(val, low, high):
low_norm = low/torch.norm(low, dim=1, keepdim=True) low_norm = low / torch.norm(low, dim=1, keepdim=True)
high_norm = high/torch.norm(high, dim=1, keepdim=True) high_norm = high / torch.norm(high, dim=1, keepdim=True)
dot = (low_norm*high_norm).sum(1) dot = (low_norm * high_norm).sum(1)
if dot.mean() > 0.9995: if dot.mean() > 0.9995:
return low * val + high * (1 - val) return low * val + high * (1 - val)
omega = torch.acos(dot) omega = torch.acos(dot)
so = torch.sin(omega) so = torch.sin(omega)
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
return res return res
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None): def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0,
p=None):
xs = [] xs = []
# if we have multiple seeds, this means we are working with batch size>1; this then # if we have multiple seeds, this means we are working with batch size>1; this then
# enables the generation of additional tensors with noise that the sampler will use during its processing. # enables the generation of additional tensors with noise that the sampler will use during its processing.
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
# produce the same images as with two batches [100], [101]. # produce the same images as with two batches [100], [101].
if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0): if p is not None and p.sampler is not None and (
len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))] sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
else: else:
sampler_noises = None sampler_noises = None
for i, seed in enumerate(seeds): for i, seed in enumerate(seeds):
noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8) noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (
shape[0], seed_resize_from_h // 8, seed_resize_from_w // 8)
subnoise = None subnoise = None
if subseeds is not None: if subseeds is not None:
@ -241,7 +251,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
dx = max(-dx, 0) dx = max(-dx, 0)
dy = max(-dy, 0) dy = max(-dy, 0)
x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w] x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w]
noise = x noise = x
if sampler_noises is not None: if sampler_noises is not None:
@ -293,14 +303,20 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Seed": all_seeds[index], "Seed": all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Size": f"{p.width}x{p.height}", "Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model hash": getattr(p, 'sd_model_hash',
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')), "Model": (
None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(
',', '').replace(':', '')),
"Hypernet": (
None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(
':', '')),
"Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch), "Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Seed resize from": (
None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None), "Denoising strength": getattr(p, 'denoising_strength', None),
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
"Clip skip": None if clip_skip <= 1 else clip_skip, "Clip skip": None if clip_skip <= 1 else clip_skip,
@ -309,7 +325,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params.update(p.extra_generation_params) generation_params.update(p.extra_generation_params)
generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None]) generation_params_text = ", ".join(
[k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else "" negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
@@ -317,7 +334,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration

-def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0,
-                   aesthetic_imgs=None,aesthetic_slerp=False) -> Processed:
+def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0,
+                   aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="",
+                   aesthetic_slerp_angle=0.15,
+                   aesthetic_text_negative=False) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""

    aesthetic_lr = float(aesthetic_lr)
@ -385,7 +404,7 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
for n in range(p.n_iter): for n in range(p.n_iter):
if state.skipped: if state.skipped:
state.skipped = False state.skipped = False
if state.interrupted: if state.interrupted:
break break
@@ -396,16 +415,19 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
        if (len(prompts) == 0):
            break

        # uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
        # c = p.sd_model.get_learned_conditioning(prompts)
        with devices.autocast():
            if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"):
-               shared.sd_model.cond_stage_model.set_aesthetic_params(0, 0, 0)
+               shared.sd_model.cond_stage_model.set_aesthetic_params()
            uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt],
                                                        p.steps)
            if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"):
-               shared.sd_model.cond_stage_model.set_aesthetic_params(aesthetic_lr, aesthetic_weight,
-                                                                     aesthetic_steps, aesthetic_imgs, aesthetic_slerp)
+               shared.sd_model.cond_stage_model.set_aesthetic_params(aesthetic_lr, aesthetic_weight,
+                                                                     aesthetic_steps, aesthetic_imgs,
+                                                                     aesthetic_slerp, aesthetic_imgs_text,
+                                                                     aesthetic_slerp_angle,
+                                                                     aesthetic_text_negative)
            c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)

        if len(model_hijack.comments) > 0:
@ -413,13 +435,13 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
comments[comment] = 1 comments[comment] = 1
if p.n_iter > 1: if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}" shared.state.job = f"Batch {n + 1} out of {p.n_iter}"
with devices.autocast(): with devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds,
subseed_strength=p.subseed_strength)
if state.interrupted or state.skipped: if state.interrupted or state.skipped:
# if we are interrupted, sample returns just noise # if we are interrupted, sample returns just noise
# use the image collected previously in sampler loop # use the image collected previously in sampler loop
samples_ddim = shared.state.current_latent samples_ddim = shared.state.current_latent
@ -445,7 +467,9 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
if p.restore_faces: if p.restore_faces:
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration: if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration") images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i],
opts.samples_format, info=infotext(n, i), p=p,
suffix="-before-face-restoration")
devices.torch_gc() devices.torch_gc()
@ -456,7 +480,8 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
if p.color_corrections is not None and i < len(p.color_corrections): if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction: if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction") images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format,
info=infotext(n, i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image) image = apply_color_correction(p.color_corrections[i], image)
if p.overlay_images is not None and i < len(p.overlay_images): if p.overlay_images is not None and i < len(p.overlay_images):
@ -474,7 +499,8 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
image = image.convert('RGB') image = image.convert('RGB')
if opts.samples_save and not p.do_not_save_samples: if opts.samples_save and not p.do_not_save_samples:
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p) images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format,
info=infotext(n, i), p=p)
text = infotext(n, i) text = infotext(n, i)
infotexts.append(text) infotexts.append(text)
@ -482,7 +508,7 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
image.info["parameters"] = text image.info["parameters"] = text
output_images.append(image) output_images.append(image)
del x_samples_ddim del x_samples_ddim
devices.torch_gc() devices.torch_gc()
@ -504,10 +530,13 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
index_of_first_image = 1 index_of_first_image = 1
if opts.grid_save: if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format,
info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc() devices.torch_gc()
return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts) return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]),
subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds,
index_of_first_image=index_of_first_image, infotexts=infotexts)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
@ -543,25 +572,34 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr: if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds,
subseeds=subseeds, subseed_strength=self.subseed_strength,
seed_resize_from_h=self.seed_resize_from_h,
seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
return samples return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds,
subseeds=subseeds, subseed_strength=self.subseed_strength,
seed_resize_from_h=self.seed_resize_from_h,
seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f
truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f
samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2] samples = samples[:, :, truncate_y // 2:samples.shape[2] - truncate_y // 2,
truncate_x // 2:samples.shape[3] - truncate_x // 2]
if self.scale_latent: if self.scale_latent:
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f),
mode="bilinear")
else: else:
decoded_samples = decode_first_stage(self.sd_model, samples) decoded_samples = decode_first_stage(self.sd_model, samples)
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None": if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear") decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width),
mode="bilinear")
else: else:
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
@ -585,13 +623,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds,
subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h,
seed_resize_from_w=self.seed_resize_from_w, p=self)
# GC now before running the next img2img to prevent running out of memory # GC now before running the next img2img to prevent running out of memory
x = None x = None
devices.torch_gc() devices.torch_gc()
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning,
steps=self.steps)
return samples return samples
@ -599,7 +640,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None sampler = None
def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs): def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4,
inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0,
**kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.init_images = init_images self.init_images = init_images
@ -607,7 +650,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.denoising_strength: float = denoising_strength self.denoising_strength: float = denoising_strength
self.init_latent = None self.init_latent = None
self.image_mask = mask self.image_mask = mask
#self.image_unblurred_mask = None # self.image_unblurred_mask = None
self.latent_mask = None self.latent_mask = None
self.mask_for_overlay = None self.mask_for_overlay = None
self.mask_blur = mask_blur self.mask_blur = mask_blur
@ -619,7 +662,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.nmask = None self.nmask = None
def init(self, all_prompts, all_seeds, all_subseeds): def init(self, all_prompts, all_seeds, all_subseeds):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model) self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index,
self.sd_model)
crop_region = None crop_region = None
if self.image_mask is not None: if self.image_mask is not None:
@ -628,7 +672,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.inpainting_mask_invert: if self.inpainting_mask_invert:
self.image_mask = ImageOps.invert(self.image_mask) self.image_mask = ImageOps.invert(self.image_mask)
#self.image_unblurred_mask = self.image_mask # self.image_unblurred_mask = self.image_mask
if self.mask_blur > 0: if self.mask_blur > 0:
self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)) self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
@ -642,7 +686,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
mask = mask.crop(crop_region) mask = mask.crop(crop_region)
self.image_mask = images.resize_image(2, mask, self.width, self.height) self.image_mask = images.resize_image(2, mask, self.width, self.height)
self.paste_to = (x1, y1, x2-x1, y2-y1) self.paste_to = (x1, y1, x2 - x1, y2 - y1)
else: else:
self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height) self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
np_mask = np.array(self.image_mask) np_mask = np.array(self.image_mask)
@ -665,7 +709,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.image_mask is not None: if self.image_mask is not None:
image_masked = Image.new('RGBa', (image.width, image.height)) image_masked = Image.new('RGBa', (image.width, image.height))
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) image_masked.paste(image.convert("RGBA").convert("RGBa"),
mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
self.overlay_images.append(image_masked.convert('RGBA')) self.overlay_images.append(image_masked.convert('RGBA'))
@ -714,12 +759,17 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
# this needs to be fixed to be done in sample() using actual seeds for batches # this needs to be fixed to be done in sample() using actual seeds for batches
if self.inpainting_fill == 2: if self.inpainting_fill == 2:
self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:],
all_seeds[
0:self.init_latent.shape[
0]]) * self.nmask
elif self.inpainting_fill == 3: elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask self.init_latent = self.init_latent * self.mask
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds,
subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h,
seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning) samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
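The slerp helper shown in the hunks above interpolates along the unit hypersphere instead of along a straight line. A minimal standalone check, using made-up 1x4 vectors and the same formula as the processing.py version (sd_hijack.py below defines it again with the argument order low, high, val), shows the endpoint behaviour:

import torch

def slerp(val, low, high):
    # Spherical interpolation, row-wise, mirroring the helper shown above.
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm * high_norm).sum(1)
    if dot.mean() > 0.9995:
        return low * val + high * (1 - val)
    omega = torch.acos(dot)
    so = torch.sin(omega)
    return (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high

low = torch.tensor([[1.0, 0.0, 0.0, 0.0]])
high = torch.tensor([[0.0, 1.0, 0.0, 0.0]])
print(slerp(0.0, low, high))  # [[1, 0, 0, 0]] -> returns low
print(slerp(0.5, low, high))  # [[0.7071, 0.7071, 0, 0]] -> halfway along the arc
print(slerp(1.0, low, high))  # [[0, 1, 0, 0]] -> returns high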

modules/sd_hijack.py

@@ -14,7 +14,8 @@ from modules.sd_hijack_optimizations import invokeAI_mps_available
import ldm.modules.attention
import ldm.modules.diffusionmodules.model

-from transformers import CLIPVisionModel, CLIPModel
+from tqdm import trange
+from transformers import CLIPVisionModel, CLIPModel, CLIPTokenizer
import torch.optim as optim
import copy
@ -22,21 +23,25 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
def apply_optimizations(): def apply_optimizations():
undo_optimizations() undo_optimizations()
ldm.modules.diffusionmodules.model.nonlinearity = silu ldm.modules.diffusionmodules.model.nonlinearity = silu
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)): if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (
6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)):
print("Applying xformers cross attention optimization.") print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
elif cmd_opts.opt_split_attention_v1: elif cmd_opts.opt_split_attention_v1:
print("Applying v1 cross attention optimization.") print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): elif not cmd_opts.disable_opt_split_attention and (
cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
if not invokeAI_mps_available and shared.device.type == 'mps': if not invokeAI_mps_available and shared.device.type == 'mps':
print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.") print(
"The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
print("Applying v1 cross attention optimization.") print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
else: else:
@ -112,14 +117,16 @@ class StableDiffusionModelHijack:
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
def slerp(low, high, val): def slerp(low, high, val):
low_norm = low/torch.norm(low, dim=1, keepdim=True) low_norm = low / torch.norm(low, dim=1, keepdim=True)
high_norm = high/torch.norm(high, dim=1, keepdim=True) high_norm = high / torch.norm(high, dim=1, keepdim=True)
omega = torch.acos((low_norm*high_norm).sum(1)) omega = torch.acos((low_norm * high_norm).sum(1))
so = torch.sin(omega) so = torch.sin(omega)
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
return res return res
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack): def __init__(self, wrapped, hijack):
super().__init__() super().__init__()
@@ -128,6 +135,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
            self.wrapped.transformer.name_or_path
        )
        del self.clipModel.vision_model
+       self.tokenizer = CLIPTokenizer.from_pretrained(self.wrapped.transformer.name_or_path)
        self.hijack: StableDiffusionModelHijack = hijack
        self.tokenizer = wrapped.tokenizer
        # self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval()
@ -139,7 +147,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0] self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if
'(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens: for text, ident in tokens_with_parens:
mult = 1.0 mult = 1.0
for c in text: for c in text:
@@ -155,8 +164,13 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
            if mult != 1.0:
                self.token_mults[ident] = mult

-   def set_aesthetic_params(self, aesthetic_lr, aesthetic_weight, aesthetic_steps, image_embs_name=None,
-                            aesthetic_slerp=True):
+   def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
+                            aesthetic_slerp=True, aesthetic_imgs_text="",
+                            aesthetic_slerp_angle=0.15,
+                            aesthetic_text_negative=False):
+       self.aesthetic_imgs_text = aesthetic_imgs_text
+       self.aesthetic_slerp_angle = aesthetic_slerp_angle
+       self.aesthetic_text_negative = aesthetic_text_negative
        self.slerp = aesthetic_slerp
        self.aesthetic_lr = aesthetic_lr
        self.aesthetic_weight = aesthetic_weight
@ -180,7 +194,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
else: else:
parsed = [[line, 1.0]] parsed = [[line, 1.0]]
tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"] tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)[
"input_ids"]
fixes = [] fixes = []
remade_tokens = [] remade_tokens = []
@ -196,18 +211,20 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if token == self.comma_token: if token == self.comma_token:
last_comma = len(remade_tokens) last_comma = len(remade_tokens)
elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack: elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens),
1) % 75 == 0 and last_comma != -1 and len(
remade_tokens) - last_comma <= opts.comma_padding_backtrack:
last_comma += 1 last_comma += 1
reloc_tokens = remade_tokens[last_comma:] reloc_tokens = remade_tokens[last_comma:]
reloc_mults = multipliers[last_comma:] reloc_mults = multipliers[last_comma:]
remade_tokens = remade_tokens[:last_comma] remade_tokens = remade_tokens[:last_comma]
length = len(remade_tokens) length = len(remade_tokens)
rem = int(math.ceil(length / 75)) * 75 - length rem = int(math.ceil(length / 75)) * 75 - length
remade_tokens += [id_end] * rem + reloc_tokens remade_tokens += [id_end] * rem + reloc_tokens
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
if embedding is None: if embedding is None:
remade_tokens.append(token) remade_tokens.append(token)
multipliers.append(weight) multipliers.append(weight)
@ -248,7 +265,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if line in cache: if line in cache:
remade_tokens, fixes, multipliers = cache[line] remade_tokens, fixes, multipliers = cache[line]
else: else:
remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms,
hijack_comments)
token_count = max(current_token_count, token_count) token_count = max(current_token_count, token_count)
cache[line] = (remade_tokens, fixes, multipliers) cache[line] = (remade_tokens, fixes, multipliers)
@ -259,7 +277,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def process_text_old(self, text): def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id id_end = self.wrapped.tokenizer.eos_token_id
@ -289,7 +306,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens): while i < len(tokens):
token = tokens[i] token = tokens[i]
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens,
i)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None: if mult_change is not None:
@ -312,11 +330,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
ovf = remade_tokens[maxlen - 2:] ovf = remade_tokens[maxlen - 2:]
overflowing_words = [vocab.get(int(x), "") for x in ovf] overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") hijack_comments.append(
f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
token_count = len(remade_tokens) token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers) cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
@ -326,23 +345,26 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
hijack_fixes.append(fixes) hijack_fixes.append(fixes)
batch_multipliers.append(multipliers) batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def forward(self, text): def forward(self, text):
use_old = opts.use_old_emphasis_implementation use_old = opts.use_old_emphasis_implementation
if use_old: if use_old:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(
text)
else: else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(
text)
self.hijack.comments += hijack_comments self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0: if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) self.hijack.comments.append(
"Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
if use_old: if use_old:
self.hijack.fixes = hijack_fixes self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers) return self.process_tokens(remade_batch_tokens, batch_multipliers)
z = None z = None
i = 0 i = 0
while max(map(len, remade_batch_tokens)) != 0: while max(map(len, remade_batch_tokens)) != 0:
@ -356,7 +378,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if fix[0] == i: if fix[0] == i:
fixes.append(fix[1]) fixes.append(fix[1])
self.hijack.fixes.append(fixes) self.hijack.fixes.append(fixes)
tokens = [] tokens = []
multipliers = [] multipliers = []
for j in range(len(remade_batch_tokens)): for j in range(len(remade_batch_tokens)):
@@ -378,19 +400,30 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                        remade_batch_tokens]

                tokens = torch.asarray(remade_batch_tokens).to(device)

+               model = copy.deepcopy(self.clipModel).to(device)
+               model.requires_grad_(True)
+               if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
+                   text_embs_2 = model.get_text_features(
+                       **self.tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device))
+                   if self.aesthetic_text_negative:
+                       text_embs_2 = self.image_embs - text_embs_2
+                       text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True)
+                   img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle)
+               else:
+                   img_embs = self.image_embs
+
                with torch.enable_grad():
-                   model = copy.deepcopy(self.clipModel).to(device)
-                   model.requires_grad_(True)

                    # We optimize the model to maximize the similarity
                    optimizer = optim.Adam(
                        model.text_model.parameters(), lr=self.aesthetic_lr
                    )

-                   for i in range(self.aesthetic_steps):
+                   for i in trange(self.aesthetic_steps, desc="Aesthetic optimization"):
                        text_embs = model.get_text_features(input_ids=tokens)
                        text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
-                       sim = text_embs @ self.image_embs.T
+                       sim = text_embs @ img_embs.T
                        loss = -sim
                        optimizer.zero_grad()
                        loss.mean().backward()

@@ -405,6 +438,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                    model.cpu()
                    del model
+                   zn = torch.concat([zn for i in range(z.shape[1] // 77)], 1)
                    if self.slerp:
                        z = slerp(z, zn, self.aesthetic_weight)
                    else:

@@ -413,15 +447,16 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                remade_batch_tokens = rem_tokens
                batch_multipliers = rem_multipliers
                i += 1

        return z
def process_tokens(self, remade_batch_tokens, batch_multipliers): def process_tokens(self, remade_batch_tokens, batch_multipliers):
if not opts.use_old_emphasis_implementation: if not opts.use_old_emphasis_implementation:
remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens] remade_batch_tokens = [
[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in
remade_batch_tokens]
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
tokens = torch.asarray(remade_batch_tokens).to(device) tokens = torch.asarray(remade_batch_tokens).to(device)
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
@ -461,8 +496,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
for fixes, tensor in zip(batch_fixes, inputs_embeds): for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes: for offset, embedding in fixes:
emb = embedding.vec emb = embedding.vec
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0]) emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]]) tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
vecs.append(tensor) vecs.append(tensor)
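The forward() changes above implement the aesthetic-gradient idea: the text encoder is briefly fine-tuned so its prompt embedding moves toward a precomputed CLIP image embedding (optionally rotated by a text embedding via slerp), and the result is blended back into the conditioning. A toy standalone sketch of that optimization loop, using a plain linear layer and random tensors instead of CLIP (all names and dimensions are made up, not WebUI code), illustrates the mechanics:

import torch
import torch.optim as optim

torch.manual_seed(0)
emb_dim = 8                                          # made-up embedding size (CLIP uses far more)
encoder = torch.nn.Linear(emb_dim, emb_dim)          # stand-in for CLIP's text encoder
prompt_tokens = torch.randn(1, emb_dim)              # stand-in for the tokenized prompt
image_embs = torch.randn(1, emb_dim)                 # stand-in for the saved .pt image embedding
image_embs = image_embs / image_embs.norm(dim=-1, keepdim=True)

optimizer = optim.Adam(encoder.parameters(), lr=1e-2)
for step in range(50):                               # "aesthetic steps"
    text_embs = encoder(prompt_tokens)
    text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
    loss = -(text_embs @ image_embs.T).mean()        # maximize cosine similarity
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f"final similarity: {(-loss).item():.3f}")     # climbs toward 1.0 as the encoder adapts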

modules/shared.py

@@ -95,6 +95,10 @@ loaded_hypernetwork = None
aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
                        os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}

+def update_aesthetic_embeddings():
+    global aesthetic_embeddings
+    aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
+                            os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}

def reload_hypernetworks():
    global hypernetworks
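update_aesthetic_embeddings simply rescans the embeddings directory so a freshly saved .pt file appears in the dropdown without a restart; the scan itself is a dict comprehension over the folder. A standalone sketch with a temporary directory (not WebUI code) shows what it produces:

import os
import tempfile

with tempfile.TemporaryDirectory() as embeddings_dir:
    # Pretend two embeddings were saved by generate_imgs_embd(), plus an unrelated file.
    for name in ("landscapes.pt", "portraits.pt", "notes.txt"):
        open(os.path.join(embeddings_dir, name), "w").close()

    aesthetic_embeddings = {f.replace(".pt", ""): os.path.join(embeddings_dir, f)
                            for f in os.listdir(embeddings_dir) if f.endswith(".pt")}
    print(sorted(aesthetic_embeddings))  # ['landscapes', 'portraits'] -- the .txt file is ignored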

modules/txt2img.py

@@ -13,7 +13,11 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
            aesthetic_lr=0,
            aesthetic_weight=0, aesthetic_steps=0,
            aesthetic_imgs=None,
-           aesthetic_slerp=False, *args):
+           aesthetic_slerp=False,
+           aesthetic_imgs_text="",
+           aesthetic_slerp_angle=0.15,
+           aesthetic_text_negative=False,
+           *args):
    p = StableDiffusionProcessingTxt2Img(
        sd_model=shared.sd_model,
        outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,

@@ -47,7 +51,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
    processed = modules.scripts.scripts_txt2img.run(p, *args)

    if processed is None:
-       processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp)
+       processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp,
+                                  aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative)

    shared.total_tqdm.clear()

modules/ui.py

@@ -41,6 +41,7 @@ from modules import prompt_parser
from modules.images import save_image
import modules.textual_inversion.ui
import modules.hypernetworks.ui
+import modules.aesthetic_clip

# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()

@@ -449,7 +450,7 @@ def create_toprow(is_img2img):
                with gr.Row():
                    negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
                with gr.Column(scale=1, elem_id="roll_col"):
                    sh = gr.Button(elem_id="sh", visible=True)
                with gr.Column(scale=1, elem_id="style_neg_col"):
                    prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)

@@ -536,9 +537,13 @@ def create_ui(wrap_gradio_gpu_call):
                        height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)

                    with gr.Group():
-                       aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
-                       aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.7)
-                       aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=50)
+                       aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.0001")
+                       aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9)
+                       aesthetic_steps = gr.Slider(minimum=0, maximum=256, step=1, label="Aesthetic steps", value=5)
+                       with gr.Row():
+                           aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="")
+                           aesthetic_slerp_angle = gr.Slider(label='Slerp angle', minimum=0, maximum=1, step=0.01, value=0.1)
+                           aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)

                        aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None)
                        aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)

@@ -617,7 +622,10 @@ def create_ui(wrap_gradio_gpu_call):
                    aesthetic_weight,
                    aesthetic_steps,
                    aesthetic_imgs,
-                   aesthetic_slerp
+                   aesthetic_slerp,
+                   aesthetic_imgs_text,
+                   aesthetic_slerp_angle,
+                   aesthetic_text_negative
                ] + custom_inputs,
                outputs=[
                    txt2img_gallery,

@@ -721,7 +729,7 @@ def create_ui(wrap_gradio_gpu_call):
                        with gr.Row():
                            inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False)
-                           inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32)
+                           inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=1024, step=4, value=32)

                    with gr.TabItem('Batch img2img', id='batch'):
                        hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''

@@ -1071,6 +1079,17 @@ def create_ui(wrap_gradio_gpu_call):
                    with gr.Column():
                        create_embedding = gr.Button(value="Create embedding", variant='primary')

+           with gr.Tab(label="Create images embedding"):
+               new_embedding_name_ae = gr.Textbox(label="Name")
+               process_src_ae = gr.Textbox(label='Source directory')
+               batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256)
+
+               with gr.Row():
+                   with gr.Column(scale=3):
+                       gr.HTML(value="")
+
+                   with gr.Column():
+                       create_embedding_ae = gr.Button(value="Create images embedding", variant='primary')
+
            with gr.Tab(label="Create hypernetwork"):
                new_hypernetwork_name = gr.Textbox(label="Name")
                new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])

@@ -1139,7 +1158,7 @@ def create_ui(wrap_gradio_gpu_call):
        fn=modules.textual_inversion.ui.create_embedding,
        inputs=[
            new_embedding_name,
-           initialization_text,
+           process_src,
            nvpt,
        ],
        outputs=[

@@ -1149,6 +1168,20 @@ def create_ui(wrap_gradio_gpu_call):
        ]
    )

+   create_embedding_ae.click(
+       fn=modules.aesthetic_clip.generate_imgs_embd,
+       inputs=[
+           new_embedding_name_ae,
+           process_src_ae,
+           batch_ae
+       ],
+       outputs=[
+           aesthetic_imgs,
+           ti_output,
+           ti_outcome,
+       ]
+   )

    create_hypernetwork.click(
        fn=modules.hypernetworks.ui.create_hypernetwork,
        inputs=[
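The new create_embedding_ae.click wiring follows the standard Gradio pattern: values are read from the listed input components, passed to the function, and the return values are routed to the output components (here the refreshed dropdown plus two status fields). A minimal self-contained sketch of that same pattern, using placeholder components and a stub function instead of the WebUI ones:

import gradio as gr

def fake_generate_imgs_embd(name, folder, batch_size):
    # Placeholder for modules.aesthetic_clip.generate_imgs_embd; just echoes its inputs.
    return f"Done generating embedding for {name}! ({folder}, batch size {int(batch_size)})", ""

with gr.Blocks() as demo:
    name = gr.Textbox(label="Name")
    folder = gr.Textbox(label="Source directory")
    batch = gr.Slider(minimum=1, maximum=1024, step=1, value=256, label="Batch size")
    create = gr.Button("Create images embedding")
    status = gr.HTML()
    errors = gr.HTML()
    create.click(fn=fake_generate_imgs_embd, inputs=[name, folder, batch], outputs=[status, errors])

# demo.launch()  # uncomment to try the wiring locally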