df57064093
add refresh button for aesthetic embeddings add aesthetic params to images' infotext
241 lines
10 KiB
Python
241 lines
10 KiB
Python
import copy
|
|
import itertools
|
|
import os
|
|
from pathlib import Path
|
|
import html
|
|
import gc
|
|
|
|
import gradio as gr
|
|
import torch
|
|
from PIL import Image
|
|
from torch import optim
|
|
|
|
from modules import shared
|
|
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer
|
|
from tqdm.auto import tqdm, trange
|
|
from modules.shared import opts, device
|
|
|
|
|
|
def get_all_images_in_folder(folder):
    """Return paths of the image files located directly inside *folder*.

    The listing is not recursive. Each path is ``os.path.join(folder, entry)``
    and results follow ``os.listdir`` order. An entry qualifies when it is a
    regular file and its name passes ``check_is_valid_image_file``.
    """
    image_paths = []
    for entry in os.listdir(folder):
        candidate = os.path.join(folder, entry)
        # isfile is checked first, mirroring the short-circuit order of the filter.
        if os.path.isfile(candidate) and check_is_valid_image_file(entry):
            image_paths.append(candidate)
    return image_paths
|
|
|
|
|
|
def check_is_valid_image_file(filename):
    """Return True when *filename* ends with a supported image extension.

    The comparison is case-insensitive. Supported suffixes are .png, .jpg,
    .jpeg, .gif, .tiff and .webp.
    """
    accepted_suffixes = ('.png', '.jpg', '.jpeg', ".gif", ".tiff", ".webp")
    lowered = filename.lower()
    return any(lowered.endswith(suffix) for suffix in accepted_suffixes)
|
|
|
|
|
|
def batched(dataset, total, n=1):
    """Yield ``dataset[0:total]`` as successive lists of at most *n* items.

    Items are fetched with integer indexing, so *dataset* only needs to
    support ``dataset[i]``. *total* may be smaller than the dataset length to
    process just a prefix. The final list is shorter when *total* is not a
    multiple of *n*.

    Raises:
        ValueError: if *n* is smaller than 1. Without this check a zero step
            fails inside ``range`` and a negative one silently yields nothing.
    """
    if n < 1:
        raise ValueError(f"batch size must be at least 1, got {n}")
    for start in range(0, total, n):
        stop = min(start + n, total)
        yield [dataset[i] for i in range(start, stop)]
|
|
|
|
|
|
def iter_to_batched(iterable, n=1):
    """Lazily split *iterable* into tuples of up to *n* consecutive items.

    The last tuple may be shorter. Iteration stops at the first empty chunk,
    so ``n == 0`` yields nothing.
    """
    source = iter(iterable)

    def take_chunk():
        # Pull the next (up to) n items from the shared iterator.
        return tuple(itertools.islice(source, n))

    # Two-argument iter() calls take_chunk until it returns the empty tuple.
    yield from iter(take_chunk, ())
|
|
|
|
|
|
def create_ui():
    """Build the collapsible "Clip Aesthetic" control panel.

    Three rows of Gradio components are created inside a group/accordion.
    They are returned as a tuple in this order: weight, steps, learning rate,
    slerp toggle, image-embedding dropdown, image text, slerp angle and
    negative-text toggle. The dropdown gets a refresh button that calls
    ``shared.update_aesthetic_embeddings`` and then reloads its choices.
    """
    import modules.ui

    def embedding_names():
        # Sorted names of the currently registered aesthetic embeddings.
        return sorted(shared.aesthetic_embeddings.keys())

    def refreshed_dropdown_args():
        # Attributes applied to the dropdown after the refresh button runs.
        return {"choices": embedding_names()}

    with gr.Group(), gr.Accordion("Open for Clip Aesthetic!", open=False):
        with gr.Row():
            weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9)
            steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)

        with gr.Row():
            learning_rate = gr.Textbox(
                label='Aesthetic learning rate',
                placeholder="Aesthetic learning rate",
                value="0.0001",
            )
            use_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
            imgs_embedding = gr.Dropdown(
                embedding_names(),
                label="Aesthetic imgs embedding",
                value="None",
            )
            modules.ui.create_refresh_button(
                imgs_embedding,
                shared.update_aesthetic_embeddings,
                refreshed_dropdown_args,
                "refresh_aesthetic_embeddings",
            )

        with gr.Row():
            imgs_text = gr.Textbox(
                label='Aesthetic text for imgs',
                placeholder="This text is used to rotate the feature space of the imgs embs",
                value="",
            )
            slerp_angle = gr.Slider(label='Slerp angle', minimum=0, maximum=1, step=0.01, value=0.1)
            text_negative = gr.Checkbox(label="Is negative text", value=False)

    return (weight, steps, learning_rate, use_slerp, imgs_embedding,
            imgs_text, slerp_angle, text_negative)
|
|
|
|
|
|
# Lazily created CLIPModel shared by this module; see aesthetic_clip().
aesthetic_clip_model = None


def aesthetic_clip():
    """Return the cached CLIPModel matching the current SD text encoder.

    The model is (re)loaded with ``CLIPModel.from_pretrained`` whenever the
    cache is empty or was built for a different ``name_or_path`` than
    ``shared.sd_model.cond_stage_model.wrapped.transformer``. A freshly loaded
    model is moved to the CPU before it is cached.
    """
    global aesthetic_clip_model

    wanted_name = shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path
    cached = aesthetic_clip_model
    if cached is not None and cached.name_or_path == wanted_name:
        return cached

    aesthetic_clip_model = CLIPModel.from_pretrained(wanted_name)
    aesthetic_clip_model.cpu()
    return aesthetic_clip_model
|
|
|
|
|
|
def _load_image(path):
    """Open the image at *path*, read its pixel data, and close the file.

    ``Image.open`` is lazy and keeps the file handle open. Loading inside the
    ``with`` block means large folders do not leak file descriptors, and the
    returned image stays usable after the handle is released.
    """
    with Image.open(path) as image:
        image.load()
    return image


def _mean_image_features(model, processor, folder, batch_size, name):
    """Average ``model.get_image_features`` over the images in *folder*.

    Images are fed to the model in chunks of *batch_size*. The loop stops
    early when ``shared.state.interrupted`` is set, and only the batches
    processed so far are averaged. Returns a CPU tensor averaged over dim 0
    with ``keepdim=True``, or ``None`` when no batch was processed (empty
    folder, or interrupted before the first batch).
    """
    embs = []
    with torch.no_grad():
        batches = iter_to_batched(get_all_images_in_folder(folder), batch_size)
        for paths in tqdm(batches, desc=f"Generating embeddings for {name}"):
            if shared.state.interrupted:
                break

            images = [_load_image(path) for path in paths]
            inputs = processor(images=images, return_tensors="pt").to(device)
            outputs = model.get_image_features(**inputs).cpu()
            embs.append(torch.clone(outputs))
            inputs.to("cpu")
            del inputs, outputs, images

        if not embs:
            # torch.cat([]) would raise here; report "nothing to save" instead.
            return None
        return torch.cat(embs, dim=0).mean(dim=0, keepdim=True)


def generate_imgs_embd(name, folder, batch_size):
    """Create the aesthetic embedding *name* from the images in *folder*.

    The mean CLIP image feature of the images directly inside *folder* is
    saved to ``<aesthetic_embeddings_dir>/<name>.pt``. The registry is then
    refreshed with ``shared.update_aesthetic_embeddings``. When no image was
    processed, nothing is written.

    Returns:
        A 4-tuple for the Gradio UI:
        - two ``gr.Dropdown.update`` objects listing the current embedding
          names, with their value reset to "None";
        - an HTML status message;
        - an empty string.
    """
    model = aesthetic_clip().to(device)
    try:
        processor = CLIPProcessor.from_pretrained(model.name_or_path)
        embs = _mean_image_features(model, processor, folder, batch_size, name)
        del processor
    finally:
        # Move the shared model back to the CPU even if an image fails to load.
        model.cpu()
        gc.collect()
        torch.cuda.empty_cache()

    safe_name = html.escape(name)
    if embs is None:
        res = f"""
    No images were processed for {safe_name}; nothing was saved.
    """
    else:
        # NOTE(review): *name* is used verbatim as a file name; a name with
        # path separators would write outside aesthetic_embeddings_dir.
        path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
        torch.save(embs, path)
        res = f"""
    Done generating embedding for {safe_name}!
    Aesthetic embedding saved to {html.escape(path)}
    """
    del embs

    shared.update_aesthetic_embeddings()
    choices = sorted(shared.aesthetic_embeddings.keys())
    return (
        gr.Dropdown.update(choices=choices, label="Imgs embedding", value="None"),
        gr.Dropdown.update(choices=list(choices), label="Imgs embedding", value="None"),
        res,
        "",
    )
|
|
|
|
|
|
def slerp(low, high, val):
    """Spherically interpolate from *low* towards *high* by fraction *val*.

    The angle between the inputs is measured after normalising both along
    ``dim=1``; the resulting weights are broadcast back over that dimension
    and applied to the original (un-normalised) *low* and *high*. ``val == 0``
    returns *low*.

    When ``sin(angle)`` is (numerically) zero, the weights take their limit
    ``(1 - val, val)``, i.e. linear interpolation. The naive formula would
    give 0/0 = NaN there.
    """
    eps = 1e-6
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    # Rounding can push the cosine of (nearly) parallel vectors just outside
    # [-1, 1], where acos returns NaN.
    cos_omega = (low_norm * high_norm).sum(1).clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)
    so = torch.sin(omega)

    degenerate = so.abs() < eps
    # Divide by 1 on degenerate entries so the unselected torch.where branch stays finite.
    safe_so = torch.where(degenerate, torch.ones_like(so), so)
    low_coef = torch.where(degenerate, torch.ones_like(so) * (1.0 - val),
                           torch.sin((1.0 - val) * omega) / safe_so)
    high_coef = torch.where(degenerate, torch.ones_like(so) * val,
                            torch.sin(val * omega) / safe_so)
    return low_coef.unsqueeze(1) * low + high_coef.unsqueeze(1) * high
|
|
|
|
|
|
class AestheticCLIP:
    """Steer CLIP text conditioning towards a saved aesthetic image embedding.

    When called, a deep copy of the model returned by ``aesthetic_clip()`` is
    fine-tuned for ``aesthetic_steps`` Adam steps. The steps increase the
    similarity between the prompt's normalised text features and the selected
    image embedding. The tuned model's hidden states are then blended into
    the conditioning ``z``, by slerp or linear mix with ``aesthetic_weight``.
    """

    def __init__(self):
        self.skip = False
        self.aesthetic_steps = 0
        self.aesthetic_weight = 0
        self.aesthetic_lr = 0
        self.slerp = False
        # Boolean flag: subtract the text features from the image embedding.
        self.aesthetic_text_negative = False
        self.aesthetic_slerp_angle = 0
        self.aesthetic_imgs_text = ""

        self.image_embs_name = None
        self.image_embs = None
        self.load_image_embs(None)

    def set_aesthetic_params(self, p, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
                             aesthetic_slerp=True, aesthetic_imgs_text="",
                             aesthetic_slerp_angle=0.15,
                             aesthetic_text_negative=False):
        """Store the per-generation aesthetic settings and select an embedding.

        When an embedding ends up selected, the settings are also added to
        ``p.extra_generation_params`` so they appear in the image infotext.
        """
        self.aesthetic_imgs_text = aesthetic_imgs_text
        self.aesthetic_slerp_angle = aesthetic_slerp_angle
        self.aesthetic_text_negative = aesthetic_text_negative
        self.slerp = aesthetic_slerp
        self.aesthetic_lr = aesthetic_lr
        self.aesthetic_weight = aesthetic_weight
        self.aesthetic_steps = aesthetic_steps
        self.load_image_embs(image_embs_name)

        if self.image_embs_name is not None:
            p.extra_generation_params.update({
                "Aesthetic LR": aesthetic_lr,
                "Aesthetic weight": aesthetic_weight,
                "Aesthetic steps": aesthetic_steps,
                "Aesthetic embedding": self.image_embs_name,
                "Aesthetic slerp": aesthetic_slerp,
                "Aesthetic text": aesthetic_imgs_text,
                "Aesthetic text negative": aesthetic_text_negative,
                "Aesthetic slerp angle": aesthetic_slerp_angle,
            })

    def set_skip(self, skip):
        """Enable or disable bypassing the aesthetic step in ``__call__``."""
        self.skip = skip

    def load_image_embs(self, image_embs_name):
        """Select the aesthetic image embedding called *image_embs_name*.

        ``None``, ``""`` and ``"None"`` disable the feature and release the
        cached tensor. Any other name is looked up in
        ``shared.aesthetic_embeddings``. The file is loaded onto ``device`` and
        L2-normalised along its last dimension. Re-selecting the name that is
        already loaded does nothing.
        """
        if not image_embs_name or image_embs_name == "None":
            self.image_embs_name = None
            # __call__ never reads the tensor while the name is None.
            self.image_embs = None
            return

        if self.image_embs_name == image_embs_name:
            # NOTE(review): the file is not re-read if it was regenerated
            # under the same name since it was loaded.
            return

        # NOTE: torch.load unpickles the file; only use trusted embedding files.
        embs = torch.load(shared.aesthetic_embeddings[image_embs_name], map_location=device)
        embs /= embs.norm(dim=-1, keepdim=True)
        embs.requires_grad_(False)
        # Commit the name only after a successful load, so a failure never
        # pairs the new name with the previous tensor.
        self.image_embs_name = image_embs_name
        self.image_embs = embs

    def _target_embedding(self, model, tokenizer):
        """Return the image embedding the text features are pulled towards.

        If ``aesthetic_imgs_text`` is set, the stored embedding is first
        slerped towards that text's features (or towards their difference
        from it when ``aesthetic_text_negative`` is set). The target is
        computed without autograd so the optimisation never backpropagates
        through, or into, it.
        """
        if not self.aesthetic_imgs_text:
            return self.image_embs
        with torch.no_grad():
            text_embs_2 = model.get_text_features(
                **tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device))
            if self.aesthetic_text_negative:
                text_embs_2 = self.image_embs - text_embs_2
                text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True)
            return slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle)

    def _optimize(self, model, tokens, img_embs):
        """Run ``aesthetic_steps`` Adam steps on ``model.text_model``.

        Each step minimises the negative mean similarity between the
        normalised text features of *tokens* and *img_embs*.
        """
        with torch.enable_grad():
            optimizer = optim.Adam(model.text_model.parameters(), lr=self.aesthetic_lr)
            for _ in trange(self.aesthetic_steps, desc="Aesthetic optimization"):
                text_embs = model.get_text_features(input_ids=tokens)
                text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
                sim = text_embs @ img_embs.T
                loss = -sim
                optimizer.zero_grad()
                loss.mean().backward()
                optimizer.step()

    @staticmethod
    def _encode(model, tokens):
        """Encode *tokens* with the tuned model, honouring ``CLIP_stop_at_last_layers``.

        Runs without autograd so the returned conditioning does not keep the
        tuned model's activation graph alive.
        """
        stop_at = opts.CLIP_stop_at_last_layers
        with torch.no_grad():
            out = model.text_model(input_ids=tokens, output_hidden_states=stop_at > 1)
            if stop_at > 1:
                return model.text_model.final_layer_norm(out.hidden_states[-stop_at])
            return out.last_hidden_state

    def __call__(self, z, remade_batch_tokens):
        """Return *z* blended with conditioning from an aesthetically tuned text model.

        *z* is returned unchanged when any of these hold:
        - skipping is enabled;
        - steps, learning rate or weight is 0;
        - no image embedding is selected.
        """
        if (self.skip or self.aesthetic_steps == 0 or self.aesthetic_lr == 0
                or self.aesthetic_weight == 0 or self.image_embs_name is None):
            return z

        tokenizer = shared.sd_model.cond_stage_model.tokenizer
        if not opts.use_old_emphasis_implementation:
            remade_batch_tokens = [
                [tokenizer.bos_token_id] + x[:75] + [tokenizer.eos_token_id]
                for x in remade_batch_tokens
            ]
        tokens = torch.asarray(remade_batch_tokens).to(device)

        # Work on a copy: the optimisation changes the weights in place.
        model = copy.deepcopy(aesthetic_clip()).to(device)
        try:
            model.requires_grad_(True)
            img_embs = self._target_embedding(model, tokenizer)
            self._optimize(model, tokens, img_embs)
            zn = self._encode(model, tokens)
        finally:
            model.cpu()
            del model
            gc.collect()
            torch.cuda.empty_cache()

        # NOTE(review): this slices zn along dim 0, while the chunk count
        # comes from z's dim 1. For z.shape[1] < 154 only zn[0:77] is used;
        # larger counts add slices that are typically empty and may fail to
        # concatenate. Confirm the intent.
        zn = torch.concat([zn[77 * i:77 * (i + 1)] for i in range(max(z.shape[1] // 77, 1))], 1)
        if self.slerp:
            return slerp(z, zn, self.aesthetic_weight)
        return z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight
|