Merge branch 'master' into fix/train-preprocess-keep-ratio
This commit is contained in:
commit
5e9afa5c8a
29 changed files with 1106 additions and 209 deletions
2
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
2
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
|
@ -45,6 +45,8 @@ body:
|
|||
attributes:
|
||||
label: Commit where the problem happens
|
||||
description: Which commit are you running ? (copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
id: platforms
|
||||
attributes:
|
||||
|
|
5
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
5
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
|
@ -0,0 +1,5 @@
|
|||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: WebUI Community Support
|
||||
url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions
|
||||
about: Please ask and answer questions here.
|
30
README.md
30
README.md
|
@ -11,6 +11,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
|
|||
- One click install and run script (but you still must install python and git)
|
||||
- Outpainting
|
||||
- Inpainting
|
||||
- Color Sketch
|
||||
- Prompt Matrix
|
||||
- Stable Diffusion Upscale
|
||||
- Attention, specify parts of text that the model should pay more attention to
|
||||
|
@ -23,6 +24,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
|
|||
- have as many embeddings as you want and use any names you like for them
|
||||
- use multiple embeddings with different numbers of vectors per token
|
||||
- works with half precision floating point numbers
|
||||
- train embeddings on 8GB (also reports of 6GB working)
|
||||
- Extras tab with:
|
||||
- GFPGAN, neural network that fixes faces
|
||||
- CodeFormer, face restoration tool as an alternative to GFPGAN
|
||||
|
@ -37,14 +39,14 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
|
|||
- Interrupt processing at any time
|
||||
- 4GB video card support (also reports of 2GB working)
|
||||
- Correct seeds for batches
|
||||
- Prompt length validation
|
||||
- get length of prompt in tokens as you type
|
||||
- get a warning after generation if some text was truncated
|
||||
- Live prompt token length validation
|
||||
- Generation parameters
|
||||
- parameters you used to generate images are saved with that image
|
||||
- in PNG chunks for PNG, in EXIF for JPEG
|
||||
- can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
|
||||
- can be disabled in settings
|
||||
- drag and drop an image/text-parameters to promptbox
|
||||
- Read Generation Parameters Button, loads parameters in promptbox to UI
|
||||
- Settings page
|
||||
- Running arbitrary python code from UI (must run with --allow-code to enable)
|
||||
- Mouseover hints for most UI elements
|
||||
|
@ -59,10 +61,10 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
|
|||
- CLIP interrogator, a button that tries to guess prompt from an image
|
||||
- Prompt Editing, a way to change prompt mid-generation, say to start making a watermelon and switch to anime girl midway
|
||||
- Batch Processing, process a group of files using img2img
|
||||
- Img2img Alternative
|
||||
- Img2img Alternative, reverse Euler method of cross attention control
|
||||
- Highres Fix, a convenience option to produce high resolution pictures in one click without usual distortions
|
||||
- Reloading checkpoints on the fly
|
||||
- Checkpoint Merger, a tab that allows you to merge two checkpoints into one
|
||||
- Checkpoint Merger, a tab that allows you to merge up to 3 checkpoints into one
|
||||
- [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from community
|
||||
- [Composable-Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/), a way to use multiple prompts at once
|
||||
- separate prompts using uppercase `AND`
|
||||
|
@ -70,14 +72,26 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
|
|||
- No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
|
||||
- DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args)
|
||||
- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args)
|
||||
- History tab: view, direct and delete images conveniently within the UI
|
||||
- Generate forever option
|
||||
- Training tab
|
||||
- hypernetworks and embeddings options
|
||||
- Preprocessing images: cropping, mirroring, autotagging using BLIP or deepdanbooru (for anime)
|
||||
- Clip skip
|
||||
- Use Hypernetworks
|
||||
- Use VAEs
|
||||
- Estimated completion time in progress bar
|
||||
- API
|
||||
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
|
||||
- Aesthetic Gradients, a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
|
||||
|
||||
|
||||
## Installation and Running
|
||||
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
|
||||
|
||||
Alternatively, use Google Colab:
|
||||
Alternatively, use online services (like Google Colab):
|
||||
|
||||
- [Colab, maintained by Akaibu](https://colab.research.google.com/drive/1kw3egmSn-KgWsikYvOMjJkVDsPLjEMzl)
|
||||
- [Colab, original by me, outdated](https://colab.research.google.com/drive/1Iy-xW9t1-OQWhb0hNxueGij8phCyluOh).
|
||||
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
|
||||
|
||||
### Automatic Installation on Windows
|
||||
1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
|
||||
|
|
|
@ -3,12 +3,12 @@ let currentWidth = null;
|
|||
let currentHeight = null;
|
||||
let arFrameTimeout = setTimeout(function(){},0);
|
||||
|
||||
function dimensionChange(e,dimname){
|
||||
function dimensionChange(e, is_width, is_height){
|
||||
|
||||
if(dimname == 'Width'){
|
||||
if(is_width){
|
||||
currentWidth = e.target.value*1.0
|
||||
}
|
||||
if(dimname == 'Height'){
|
||||
if(is_height){
|
||||
currentHeight = e.target.value*1.0
|
||||
}
|
||||
|
||||
|
@ -18,22 +18,13 @@ function dimensionChange(e,dimname){
|
|||
return;
|
||||
}
|
||||
|
||||
var img2imgMode = gradioApp().querySelector('#mode_img2img.tabs > div > button.rounded-t-lg.border-gray-200')
|
||||
if(img2imgMode){
|
||||
img2imgMode=img2imgMode.innerText
|
||||
}else{
|
||||
return;
|
||||
}
|
||||
|
||||
var redrawImage = gradioApp().querySelector('div[data-testid=image] img');
|
||||
var inpaintImage = gradioApp().querySelector('#img2maskimg div[data-testid=image] img')
|
||||
|
||||
var targetElement = null;
|
||||
|
||||
if(img2imgMode=='img2img' && redrawImage){
|
||||
targetElement = redrawImage;
|
||||
}else if(img2imgMode=='Inpaint' && inpaintImage){
|
||||
targetElement = inpaintImage;
|
||||
var tabIndex = get_tab_index('mode_img2img')
|
||||
if(tabIndex == 0){
|
||||
targetElement = gradioApp().querySelector('div[data-testid=image] img');
|
||||
} else if(tabIndex == 1){
|
||||
targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
|
||||
}
|
||||
|
||||
if(targetElement){
|
||||
|
@ -98,22 +89,20 @@ onUiUpdate(function(){
|
|||
var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
|
||||
if(inImg2img){
|
||||
let inputs = gradioApp().querySelectorAll('input');
|
||||
inputs.forEach(function(e){
|
||||
let parentLabel = e.parentElement.querySelector('label')
|
||||
if(parentLabel && parentLabel.innerText){
|
||||
if(!e.classList.contains('scrollwatch')){
|
||||
if(parentLabel.innerText == 'Width' || parentLabel.innerText == 'Height'){
|
||||
e.addEventListener('input', function(e){dimensionChange(e,parentLabel.innerText)} )
|
||||
e.classList.add('scrollwatch')
|
||||
}
|
||||
if(parentLabel.innerText == 'Width'){
|
||||
currentWidth = e.value*1.0
|
||||
}
|
||||
if(parentLabel.innerText == 'Height'){
|
||||
currentHeight = e.value*1.0
|
||||
}
|
||||
}
|
||||
}
|
||||
inputs.forEach(function(e){
|
||||
var is_width = e.parentElement.id == "img2img_width"
|
||||
var is_height = e.parentElement.id == "img2img_height"
|
||||
|
||||
if((is_width || is_height) && !e.classList.contains('scrollwatch')){
|
||||
e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} )
|
||||
e.classList.add('scrollwatch')
|
||||
}
|
||||
if(is_width){
|
||||
currentWidth = e.value*1.0
|
||||
}
|
||||
if(is_height){
|
||||
currentHeight = e.value*1.0
|
||||
}
|
||||
})
|
||||
}
|
||||
});
|
||||
|
|
2
javascript/dragdrop.js
vendored
2
javascript/dragdrop.js
vendored
|
@ -43,7 +43,7 @@ function dropReplaceImage( imgWrap, files ) {
|
|||
window.document.addEventListener('dragover', e => {
|
||||
const target = e.composedPath()[0];
|
||||
const imgWrap = target.closest('[data-testid="image"]');
|
||||
if ( !imgWrap && target.placeholder.indexOf("Prompt") == -1) {
|
||||
if ( !imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
|
||||
return;
|
||||
}
|
||||
e.stopPropagation();
|
||||
|
|
241
modules/aesthetic_clip.py
Normal file
241
modules/aesthetic_clip.py
Normal file
|
@ -0,0 +1,241 @@
|
|||
import copy
|
||||
import itertools
|
||||
import os
|
||||
from pathlib import Path
|
||||
import html
|
||||
import gc
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import optim
|
||||
|
||||
from modules import shared
|
||||
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer
|
||||
from tqdm.auto import tqdm, trange
|
||||
from modules.shared import opts, device
|
||||
|
||||
|
||||
def get_all_images_in_folder(folder):
|
||||
return [os.path.join(folder, f) for f in os.listdir(folder) if
|
||||
os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)]
|
||||
|
||||
|
||||
def check_is_valid_image_file(filename):
|
||||
return filename.lower().endswith(('.png', '.jpg', '.jpeg', ".gif", ".tiff", ".webp"))
|
||||
|
||||
|
||||
def batched(dataset, total, n=1):
|
||||
for ndx in range(0, total, n):
|
||||
yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))]
|
||||
|
||||
|
||||
def iter_to_batched(iterable, n=1):
|
||||
it = iter(iterable)
|
||||
while True:
|
||||
chunk = tuple(itertools.islice(it, n))
|
||||
if not chunk:
|
||||
return
|
||||
yield chunk
|
||||
|
||||
|
||||
def create_ui():
|
||||
import modules.ui
|
||||
|
||||
with gr.Group():
|
||||
with gr.Accordion("Open for Clip Aesthetic!", open=False):
|
||||
with gr.Row():
|
||||
aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight",
|
||||
value=0.9)
|
||||
aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
|
||||
|
||||
with gr.Row():
|
||||
aesthetic_lr = gr.Textbox(label='Aesthetic learning rate',
|
||||
placeholder="Aesthetic learning rate", value="0.0001")
|
||||
aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
|
||||
aesthetic_imgs = gr.Dropdown(sorted(shared.aesthetic_embeddings.keys()),
|
||||
label="Aesthetic imgs embedding",
|
||||
value="None")
|
||||
|
||||
modules.ui.create_refresh_button(aesthetic_imgs, shared.update_aesthetic_embeddings, lambda: {"choices": sorted(shared.aesthetic_embeddings.keys())}, "refresh_aesthetic_embeddings")
|
||||
|
||||
with gr.Row():
|
||||
aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs',
|
||||
placeholder="This text is used to rotate the feature space of the imgs embs",
|
||||
value="")
|
||||
aesthetic_slerp_angle = gr.Slider(label='Slerp angle', minimum=0, maximum=1, step=0.01,
|
||||
value=0.1)
|
||||
aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
|
||||
|
||||
return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative
|
||||
|
||||
|
||||
aesthetic_clip_model = None
|
||||
|
||||
|
||||
def aesthetic_clip():
|
||||
global aesthetic_clip_model
|
||||
|
||||
if aesthetic_clip_model is None or aesthetic_clip_model.name_or_path != shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path:
|
||||
aesthetic_clip_model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path)
|
||||
aesthetic_clip_model.cpu()
|
||||
|
||||
return aesthetic_clip_model
|
||||
|
||||
|
||||
def generate_imgs_embd(name, folder, batch_size):
|
||||
model = aesthetic_clip().to(device)
|
||||
processor = CLIPProcessor.from_pretrained(model.name_or_path)
|
||||
|
||||
with torch.no_grad():
|
||||
embs = []
|
||||
for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size),
|
||||
desc=f"Generating embeddings for {name}"):
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device)
|
||||
outputs = model.get_image_features(**inputs).cpu()
|
||||
embs.append(torch.clone(outputs))
|
||||
inputs.to("cpu")
|
||||
del inputs, outputs
|
||||
|
||||
embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)
|
||||
|
||||
# The generated embedding will be located here
|
||||
path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
|
||||
torch.save(embs, path)
|
||||
|
||||
model.cpu()
|
||||
del processor
|
||||
del embs
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
res = f"""
|
||||
Done generating embedding for {name}!
|
||||
Aesthetic embedding saved to {html.escape(path)}
|
||||
"""
|
||||
shared.update_aesthetic_embeddings()
|
||||
return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding",
|
||||
value="None"), \
|
||||
gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()),
|
||||
label="Imgs embedding",
|
||||
value="None"), res, ""
|
||||
|
||||
|
||||
def slerp(low, high, val):
|
||||
low_norm = low / torch.norm(low, dim=1, keepdim=True)
|
||||
high_norm = high / torch.norm(high, dim=1, keepdim=True)
|
||||
omega = torch.acos((low_norm * high_norm).sum(1))
|
||||
so = torch.sin(omega)
|
||||
res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
|
||||
return res
|
||||
|
||||
|
||||
class AestheticCLIP:
|
||||
def __init__(self):
|
||||
self.skip = False
|
||||
self.aesthetic_steps = 0
|
||||
self.aesthetic_weight = 0
|
||||
self.aesthetic_lr = 0
|
||||
self.slerp = False
|
||||
self.aesthetic_text_negative = ""
|
||||
self.aesthetic_slerp_angle = 0
|
||||
self.aesthetic_imgs_text = ""
|
||||
|
||||
self.image_embs_name = None
|
||||
self.image_embs = None
|
||||
self.load_image_embs(None)
|
||||
|
||||
def set_aesthetic_params(self, p, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
|
||||
aesthetic_slerp=True, aesthetic_imgs_text="",
|
||||
aesthetic_slerp_angle=0.15,
|
||||
aesthetic_text_negative=False):
|
||||
self.aesthetic_imgs_text = aesthetic_imgs_text
|
||||
self.aesthetic_slerp_angle = aesthetic_slerp_angle
|
||||
self.aesthetic_text_negative = aesthetic_text_negative
|
||||
self.slerp = aesthetic_slerp
|
||||
self.aesthetic_lr = aesthetic_lr
|
||||
self.aesthetic_weight = aesthetic_weight
|
||||
self.aesthetic_steps = aesthetic_steps
|
||||
self.load_image_embs(image_embs_name)
|
||||
|
||||
if self.image_embs_name is not None:
|
||||
p.extra_generation_params.update({
|
||||
"Aesthetic LR": aesthetic_lr,
|
||||
"Aesthetic weight": aesthetic_weight,
|
||||
"Aesthetic steps": aesthetic_steps,
|
||||
"Aesthetic embedding": self.image_embs_name,
|
||||
"Aesthetic slerp": aesthetic_slerp,
|
||||
"Aesthetic text": aesthetic_imgs_text,
|
||||
"Aesthetic text negative": aesthetic_text_negative,
|
||||
"Aesthetic slerp angle": aesthetic_slerp_angle,
|
||||
})
|
||||
|
||||
def set_skip(self, skip):
|
||||
self.skip = skip
|
||||
|
||||
def load_image_embs(self, image_embs_name):
|
||||
if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None":
|
||||
image_embs_name = None
|
||||
self.image_embs_name = None
|
||||
if image_embs_name is not None and self.image_embs_name != image_embs_name:
|
||||
self.image_embs_name = image_embs_name
|
||||
self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device)
|
||||
self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True)
|
||||
self.image_embs.requires_grad_(False)
|
||||
|
||||
def __call__(self, z, remade_batch_tokens):
|
||||
if not self.skip and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name is not None:
|
||||
tokenizer = shared.sd_model.cond_stage_model.tokenizer
|
||||
if not opts.use_old_emphasis_implementation:
|
||||
remade_batch_tokens = [
|
||||
[tokenizer.bos_token_id] + x[:75] + [tokenizer.eos_token_id] for x in
|
||||
remade_batch_tokens]
|
||||
|
||||
tokens = torch.asarray(remade_batch_tokens).to(device)
|
||||
|
||||
model = copy.deepcopy(aesthetic_clip()).to(device)
|
||||
model.requires_grad_(True)
|
||||
if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
|
||||
text_embs_2 = model.get_text_features(
|
||||
**tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device))
|
||||
if self.aesthetic_text_negative:
|
||||
text_embs_2 = self.image_embs - text_embs_2
|
||||
text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True)
|
||||
img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle)
|
||||
else:
|
||||
img_embs = self.image_embs
|
||||
|
||||
with torch.enable_grad():
|
||||
|
||||
# We optimize the model to maximize the similarity
|
||||
optimizer = optim.Adam(
|
||||
model.text_model.parameters(), lr=self.aesthetic_lr
|
||||
)
|
||||
|
||||
for _ in trange(self.aesthetic_steps, desc="Aesthetic optimization"):
|
||||
text_embs = model.get_text_features(input_ids=tokens)
|
||||
text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
|
||||
sim = text_embs @ img_embs.T
|
||||
loss = -sim
|
||||
optimizer.zero_grad()
|
||||
loss.mean().backward()
|
||||
optimizer.step()
|
||||
|
||||
zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
|
||||
if opts.CLIP_stop_at_last_layers > 1:
|
||||
zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers]
|
||||
zn = model.text_model.final_layer_norm(zn)
|
||||
else:
|
||||
zn = zn.last_hidden_state
|
||||
model.cpu()
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
zn = torch.concat([zn[77 * i:77 * (i + 1)] for i in range(max(z.shape[1] // 77, 1))], 1)
|
||||
if self.slerp:
|
||||
z = slerp(z, zn, self.aesthetic_weight)
|
||||
else:
|
||||
z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight
|
||||
|
||||
return z
|
|
@ -39,9 +39,12 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
|||
|
||||
if input_dir == '':
|
||||
return outputs, "Please select an input directory.", ''
|
||||
image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
|
||||
image_list = [file for file in [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))] if os.path.isfile(file)]
|
||||
for img in image_list:
|
||||
image = Image.open(img)
|
||||
try:
|
||||
image = Image.open(img)
|
||||
except Exception:
|
||||
continue
|
||||
imageArr.append(image)
|
||||
imageNameArr.append(img)
|
||||
else:
|
||||
|
@ -118,10 +121,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
|||
|
||||
while len(cached_images) > 2:
|
||||
del cached_images[next(iter(cached_images.keys()))]
|
||||
|
||||
if opts.use_original_name_batch and image_name != None:
|
||||
basename = os.path.splitext(os.path.basename(image_name))[0]
|
||||
else:
|
||||
basename = ''
|
||||
|
||||
images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
|
||||
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
|
||||
forced_filename=image_name if opts.use_original_name_batch else None)
|
||||
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
|
||||
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
|
||||
|
||||
if opts.enable_pnginfo:
|
||||
image.info = existing_pnginfo
|
||||
|
|
|
@ -4,13 +4,22 @@ import gradio as gr
|
|||
from modules.shared import script_path
|
||||
from modules import shared
|
||||
|
||||
re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
|
||||
re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
|
||||
re_param = re.compile(re_param_code)
|
||||
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
|
||||
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
||||
type_of_gr_update = type(gr.update())
|
||||
|
||||
|
||||
def quote(text):
|
||||
if ',' not in str(text):
|
||||
return text
|
||||
|
||||
text = str(text)
|
||||
text = text.replace('\\', '\\\\')
|
||||
text = text.replace('"', '\\"')
|
||||
return f'"{text}"'
|
||||
|
||||
def parse_generation_parameters(x: str):
|
||||
"""parses generation parameters string, the one you see in text field under the picture in UI:
|
||||
```
|
||||
|
@ -83,7 +92,12 @@ def connect_paste(button, paste_fields, input_comp, js=None):
|
|||
else:
|
||||
try:
|
||||
valtype = type(output.value)
|
||||
val = valtype(v)
|
||||
|
||||
if valtype == bool and v == "False":
|
||||
val = False
|
||||
else:
|
||||
val = valtype(v)
|
||||
|
||||
res.append(gr.update(value=val))
|
||||
except Exception:
|
||||
res.append(gr.update())
|
||||
|
|
|
@ -22,16 +22,26 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
|||
class HypernetworkModule(torch.nn.Module):
|
||||
multiplier = 1.0
|
||||
|
||||
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
|
||||
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
|
||||
super().__init__()
|
||||
|
||||
assert layer_structure is not None, "layer_structure mut not be None"
|
||||
assert layer_structure is not None, "layer_structure must not be None"
|
||||
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
||||
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
||||
|
||||
linears = []
|
||||
for i in range(len(layer_structure) - 1):
|
||||
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
||||
|
||||
if activation_func == "relu":
|
||||
linears.append(torch.nn.ReLU())
|
||||
elif activation_func == "leakyrelu":
|
||||
linears.append(torch.nn.LeakyReLU())
|
||||
elif activation_func == 'linear' or activation_func is None:
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
|
||||
|
||||
if add_layer_norm:
|
||||
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
||||
|
||||
|
@ -42,8 +52,9 @@ class HypernetworkModule(torch.nn.Module):
|
|||
self.load_state_dict(state_dict)
|
||||
else:
|
||||
for layer in self.linear:
|
||||
layer.weight.data.normal_(mean=0.0, std=0.01)
|
||||
layer.bias.data.zero_()
|
||||
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
||||
layer.weight.data.normal_(mean=0.0, std=0.01)
|
||||
layer.bias.data.zero_()
|
||||
|
||||
self.to(devices.device)
|
||||
|
||||
|
@ -69,7 +80,8 @@ class HypernetworkModule(torch.nn.Module):
|
|||
def trainables(self):
|
||||
layer_structure = []
|
||||
for layer in self.linear:
|
||||
layer_structure += [layer.weight, layer.bias]
|
||||
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
||||
layer_structure += [layer.weight, layer.bias]
|
||||
return layer_structure
|
||||
|
||||
|
||||
|
@ -81,7 +93,7 @@ class Hypernetwork:
|
|||
filename = None
|
||||
name = None
|
||||
|
||||
def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False):
|
||||
def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None):
|
||||
self.filename = None
|
||||
self.name = name
|
||||
self.layers = {}
|
||||
|
@ -90,11 +102,12 @@ class Hypernetwork:
|
|||
self.sd_checkpoint_name = None
|
||||
self.layer_structure = layer_structure
|
||||
self.add_layer_norm = add_layer_norm
|
||||
self.activation_func = activation_func
|
||||
|
||||
for size in enable_sizes or []:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
|
||||
)
|
||||
|
||||
def weights(self):
|
||||
|
@ -117,6 +130,7 @@ class Hypernetwork:
|
|||
state_dict['name'] = self.name
|
||||
state_dict['layer_structure'] = self.layer_structure
|
||||
state_dict['is_layer_norm'] = self.add_layer_norm
|
||||
state_dict['activation_func'] = self.activation_func
|
||||
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
||||
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
||||
|
||||
|
@ -131,12 +145,13 @@ class Hypernetwork:
|
|||
|
||||
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
||||
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||
self.activation_func = state_dict.get('activation_func', None)
|
||||
|
||||
for size, sd in state_dict.items():
|
||||
if type(size) == int:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm),
|
||||
HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm),
|
||||
HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func),
|
||||
HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func),
|
||||
)
|
||||
|
||||
self.name = state_dict.get('name', self.name)
|
||||
|
@ -241,6 +256,9 @@ def stack_conds(conds):
|
|||
|
||||
|
||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||
from modules import images
|
||||
|
||||
assert hypernetwork_name, 'hypernetwork not selected'
|
||||
|
||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||
|
@ -283,6 +301,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||
|
||||
last_saved_file = "<none>"
|
||||
last_saved_image = "<none>"
|
||||
forced_filename = "<none>"
|
||||
|
||||
ititial_step = hypernetwork.step or 0
|
||||
if ititial_step > steps:
|
||||
|
@ -321,7 +340,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||
pbar.set_description(f"loss: {mean_loss:.7f}")
|
||||
|
||||
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
|
||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
|
||||
# Before saving, change name to match current checkpoint.
|
||||
hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
|
||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
|
||||
hypernetwork.save(last_saved_file)
|
||||
|
||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
||||
|
@ -330,7 +351,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||
})
|
||||
|
||||
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
|
||||
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
|
||||
forced_filename = f'{hypernetwork_name}-{hypernetwork.step}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
|
||||
optimizer.zero_grad()
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
|
@ -366,7 +388,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||
|
||||
if image is not None:
|
||||
shared.state.current_image = image
|
||||
image.save(last_saved_image)
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
shared.state.job_no = hypernetwork.step
|
||||
|
@ -376,7 +398,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||
Loss: {mean_loss:.7f}<br/>
|
||||
Step: {hypernetwork.step}<br/>
|
||||
Last prompt: {html.escape(entries[0].cond_text)}<br/>
|
||||
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
|
||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
</p>
|
||||
"""
|
||||
|
@ -385,6 +407,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||
|
||||
hypernetwork.sd_checkpoint = checkpoint.hash
|
||||
hypernetwork.sd_checkpoint_name = checkpoint.model_name
|
||||
# Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention).
|
||||
hypernetwork.name = hypernetwork_name
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt')
|
||||
hypernetwork.save(filename)
|
||||
|
||||
return hypernetwork, filename
|
||||
|
|
|
@ -10,9 +10,13 @@ from modules import sd_hijack, shared, devices
|
|||
from modules.hypernetworks import hypernetwork
|
||||
|
||||
|
||||
def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False):
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, add_layer_norm=False, activation_func=None):
|
||||
# Remove illegal characters from name.
|
||||
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||
|
||||
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
||||
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||
if not overwrite_old:
|
||||
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||
|
||||
if type(layer_structure) == str:
|
||||
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
|
||||
|
@ -22,6 +26,7 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm
|
|||
enable_sizes=[int(x) for x in enable_sizes],
|
||||
layer_structure=layer_structure,
|
||||
add_layer_norm=add_layer_norm,
|
||||
activation_func=activation_func,
|
||||
)
|
||||
hypernet.save(fn)
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ def process_batch(p, input_dir, output_dir, args):
|
|||
processed_image.save(os.path.join(output_dir, filename))
|
||||
|
||||
|
||||
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
|
||||
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args):
|
||||
is_inpaint = mode == 1
|
||||
is_batch = mode == 2
|
||||
|
||||
|
@ -109,6 +109,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
|
|||
inpainting_mask_invert=inpainting_mask_invert,
|
||||
)
|
||||
|
||||
shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative)
|
||||
|
||||
if shared.cmd_opts.enable_console_prompts:
|
||||
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
||||
|
||||
|
|
|
@ -28,9 +28,11 @@ class InterrogateModels:
|
|||
clip_preprocess = None
|
||||
categories = None
|
||||
dtype = None
|
||||
running_on_cpu = None
|
||||
|
||||
def __init__(self, content_dir):
|
||||
self.categories = []
|
||||
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
|
||||
|
||||
if os.path.exists(content_dir):
|
||||
for filename in os.listdir(content_dir):
|
||||
|
@ -53,7 +55,11 @@ class InterrogateModels:
|
|||
def load_clip_model(self):
|
||||
import clip
|
||||
|
||||
model, preprocess = clip.load(clip_model_name)
|
||||
if self.running_on_cpu:
|
||||
model, preprocess = clip.load(clip_model_name, device="cpu")
|
||||
else:
|
||||
model, preprocess = clip.load(clip_model_name)
|
||||
|
||||
model.eval()
|
||||
model = model.to(devices.device_interrogate)
|
||||
|
||||
|
@ -62,14 +68,14 @@ class InterrogateModels:
|
|||
def load(self):
|
||||
if self.blip_model is None:
|
||||
self.blip_model = self.load_blip_model()
|
||||
if not shared.cmd_opts.no_half:
|
||||
if not shared.cmd_opts.no_half and not self.running_on_cpu:
|
||||
self.blip_model = self.blip_model.half()
|
||||
|
||||
self.blip_model = self.blip_model.to(devices.device_interrogate)
|
||||
|
||||
if self.clip_model is None:
|
||||
self.clip_model, self.clip_preprocess = self.load_clip_model()
|
||||
if not shared.cmd_opts.no_half:
|
||||
if not shared.cmd_opts.no_half and not self.running_on_cpu:
|
||||
self.clip_model = self.clip_model.half()
|
||||
|
||||
self.clip_model = self.clip_model.to(devices.device_interrogate)
|
||||
|
|
|
@ -12,7 +12,7 @@ from skimage import exposure
|
|||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import modules.sd_hijack
|
||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram
|
||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste
|
||||
from modules.sd_hijack import model_hijack
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
|
@ -304,7 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
|||
"Size": f"{p.width}x{p.height}",
|
||||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
||||
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
||||
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.filename.split('\\')[-1].split('.')[0]),
|
||||
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
|
||||
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
||||
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
||||
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
||||
|
@ -318,7 +318,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
|||
|
||||
generation_params.update(p.extra_generation_params)
|
||||
|
||||
generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
|
||||
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
||||
|
||||
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
|
||||
|
||||
|
@ -540,17 +540,37 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
|
||||
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
|
||||
|
||||
def create_dummy_mask(self, x, width=None, height=None):
|
||||
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
||||
height = height or self.height
|
||||
width = width or self.width
|
||||
|
||||
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
||||
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
|
||||
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
|
||||
|
||||
# Add the fake full 1s mask to the first dimension.
|
||||
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
||||
image_conditioning = image_conditioning.to(x.dtype)
|
||||
|
||||
else:
|
||||
# Dummy zero conditioning if we're not using inpainting model.
|
||||
# Still takes up a bit of memory, but no encoder call.
|
||||
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
|
||||
image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
|
||||
|
||||
return image_conditioning
|
||||
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
||||
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
||||
|
||||
if not self.enable_hr:
|
||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
|
||||
return samples
|
||||
|
||||
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height))
|
||||
|
||||
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
|
||||
|
||||
|
@ -587,7 +607,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||
x = None
|
||||
devices.torch_gc()
|
||||
|
||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
|
||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples))
|
||||
|
||||
return samples
|
||||
|
||||
|
@ -613,6 +633,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||
self.inpainting_mask_invert = inpainting_mask_invert
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.image_conditioning = None
|
||||
|
||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
|
||||
|
@ -714,10 +735,39 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||
elif self.inpainting_fill == 3:
|
||||
self.init_latent = self.init_latent * self.mask
|
||||
|
||||
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
||||
if self.image_mask is not None:
|
||||
conditioning_mask = np.array(self.image_mask.convert("L"))
|
||||
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
|
||||
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
|
||||
|
||||
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
|
||||
conditioning_mask = torch.round(conditioning_mask)
|
||||
else:
|
||||
conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
|
||||
|
||||
# Create another latent image, this time with a masked version of the original input.
|
||||
conditioning_mask = conditioning_mask.to(image.device)
|
||||
conditioning_image = image * (1.0 - conditioning_mask)
|
||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
|
||||
|
||||
# Create the concatenated conditioning tensor to be fed to `c_concat`
|
||||
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
|
||||
conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
|
||||
self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
|
||||
self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
|
||||
else:
|
||||
self.image_conditioning = torch.zeros(
|
||||
self.init_latent.shape[0], 5, 1, 1,
|
||||
dtype=self.init_latent.dtype,
|
||||
device=self.init_latent.device
|
||||
)
|
||||
|
||||
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
|
||||
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
|
||||
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
|
||||
|
||||
if self.mask is not None:
|
||||
samples = samples * self.nmask + self.init_latent * self.mask
|
||||
|
|
|
@ -19,6 +19,7 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
|||
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||
|
||||
|
||||
def apply_optimizations():
|
||||
undo_optimizations()
|
||||
|
||||
|
@ -167,11 +168,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
|
||||
remade_tokens = remade_tokens[:last_comma]
|
||||
length = len(remade_tokens)
|
||||
|
||||
|
||||
rem = int(math.ceil(length / 75)) * 75 - length
|
||||
remade_tokens += [id_end] * rem + reloc_tokens
|
||||
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
|
||||
|
||||
|
||||
if embedding is None:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(weight)
|
||||
|
@ -223,7 +224,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
|
||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||
|
||||
|
||||
def process_text_old(self, text):
|
||||
id_start = self.wrapped.tokenizer.bos_token_id
|
||||
id_end = self.wrapped.tokenizer.eos_token_id
|
||||
|
@ -280,7 +280,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
|
||||
token_count = len(remade_tokens)
|
||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
||||
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
|
||||
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
||||
|
||||
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
||||
|
@ -290,7 +290,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
hijack_fixes.append(fixes)
|
||||
batch_multipliers.append(multipliers)
|
||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||
|
||||
|
||||
def forward(self, text):
|
||||
use_old = opts.use_old_emphasis_implementation
|
||||
if use_old:
|
||||
|
@ -302,11 +302,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
|
||||
if len(used_custom_terms) > 0:
|
||||
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
||||
|
||||
|
||||
if use_old:
|
||||
self.hijack.fixes = hijack_fixes
|
||||
return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
||||
|
||||
|
||||
z = None
|
||||
i = 0
|
||||
while max(map(len, remade_batch_tokens)) != 0:
|
||||
|
@ -320,7 +320,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
if fix[0] == i:
|
||||
fixes.append(fix[1])
|
||||
self.hijack.fixes.append(fixes)
|
||||
|
||||
|
||||
tokens = []
|
||||
multipliers = []
|
||||
for j in range(len(remade_batch_tokens)):
|
||||
|
@ -332,20 +332,20 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
multipliers.append([1.0] * 75)
|
||||
|
||||
z1 = self.process_tokens(tokens, multipliers)
|
||||
z1 = shared.aesthetic_clip(z1, remade_batch_tokens)
|
||||
z = z1 if z is None else torch.cat((z, z1), axis=-2)
|
||||
|
||||
|
||||
remade_batch_tokens = rem_tokens
|
||||
batch_multipliers = rem_multipliers
|
||||
i += 1
|
||||
|
||||
|
||||
return z
|
||||
|
||||
|
||||
|
||||
def process_tokens(self, remade_batch_tokens, batch_multipliers):
|
||||
if not opts.use_old_emphasis_implementation:
|
||||
remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
|
||||
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
|
||||
|
||||
|
||||
tokens = torch.asarray(remade_batch_tokens).to(device)
|
||||
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
|
||||
|
||||
|
@ -385,8 +385,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
|
|||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||
for offset, embedding in fixes:
|
||||
emb = embedding.vec
|
||||
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
|
||||
tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
|
||||
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
||||
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
|
||||
|
||||
vecs.append(tensor)
|
||||
|
||||
|
|
331
modules/sd_hijack_inpainting.py
Normal file
331
modules/sd_hijack_inpainting.py
Normal file
|
@ -0,0 +1,331 @@
|
|||
import torch
|
||||
|
||||
from einops import repeat
|
||||
from omegaconf import ListConfig
|
||||
|
||||
import ldm.models.diffusion.ddpm
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
|
||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
|
||||
|
||||
# =================================================================================================
|
||||
# Monkey patch DDIMSampler methods from RunwayML repo directly.
|
||||
# Adapted from:
|
||||
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
|
||||
# =================================================================================================
|
||||
@torch.no_grad()
|
||||
def sample_ddim(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list):
|
||||
ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
||||
|
||||
samples, intermediates = self.ddim_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
if isinstance(c, dict):
|
||||
assert isinstance(unconditional_conditioning, dict)
|
||||
c_in = dict()
|
||||
for k in c:
|
||||
if isinstance(c[k], list):
|
||||
c_in[k] = [
|
||||
torch.cat([unconditional_conditioning[k][i], c[k][i]])
|
||||
for i in range(len(c[k]))
|
||||
]
|
||||
else:
|
||||
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
|
||||
else:
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
|
||||
# =================================================================================================
|
||||
# Monkey patch PLMSSampler methods.
|
||||
# This one was not actually patched correctly in the RunwayML repo, but we can replicate the changes.
|
||||
# Adapted from:
|
||||
# https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/plms.py
|
||||
# =================================================================================================
|
||||
@torch.no_grad()
|
||||
def sample_plms(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list):
|
||||
ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for PLMS sampling is {size}')
|
||||
|
||||
samples, intermediates = self.plms_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
def get_model_output(x, t):
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
|
||||
if isinstance(c, dict):
|
||||
assert isinstance(unconditional_conditioning, dict)
|
||||
c_in = dict()
|
||||
for k in c:
|
||||
if isinstance(c[k], list):
|
||||
c_in[k] = [
|
||||
torch.cat([unconditional_conditioning[k][i], c[k][i]])
|
||||
for i in range(len(c[k]))
|
||||
]
|
||||
else:
|
||||
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
|
||||
else:
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
return e_t
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
e_t = get_model_output(x, t)
|
||||
if len(old_eps) == 0:
|
||||
# Pseudo Improved Euler (2nd order)
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||
e_t_next = get_model_output(x_prev, t_next)
|
||||
e_t_prime = (e_t + e_t_next) / 2
|
||||
elif len(old_eps) == 1:
|
||||
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||
elif len(old_eps) == 2:
|
||||
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||
elif len(old_eps) >= 3:
|
||||
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||
|
||||
return x_prev, pred_x0, e_t
|
||||
|
||||
# =================================================================================================
|
||||
# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
|
||||
# Adapted from:
|
||||
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py
|
||||
# =================================================================================================
|
||||
|
||||
@torch.no_grad()
|
||||
def get_unconditional_conditioning(self, batch_size, null_label=None):
|
||||
if null_label is not None:
|
||||
xc = null_label
|
||||
if isinstance(xc, ListConfig):
|
||||
xc = list(xc)
|
||||
if isinstance(xc, dict) or isinstance(xc, list):
|
||||
c = self.get_learned_conditioning(xc)
|
||||
else:
|
||||
if hasattr(xc, "to"):
|
||||
xc = xc.to(self.device)
|
||||
c = self.get_learned_conditioning(xc)
|
||||
else:
|
||||
# todo: get null label from cond_stage_model
|
||||
raise NotImplementedError()
|
||||
c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
|
||||
return c
|
||||
|
||||
|
||||
class LatentInpaintDiffusion(LatentDiffusion):
|
||||
def __init__(
|
||||
self,
|
||||
concat_keys=("mask", "masked_image"),
|
||||
masked_image_key="masked_image",
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.masked_image_key = masked_image_key
|
||||
assert self.masked_image_key in concat_keys
|
||||
self.concat_keys = concat_keys
|
||||
|
||||
|
||||
def should_hijack_inpainting(checkpoint_info):
|
||||
return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml")
|
||||
|
||||
|
||||
def do_inpainting_hijack():
|
||||
ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
|
||||
ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
|
||||
|
||||
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
|
||||
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
|
||||
|
||||
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
||||
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
|
|
@ -9,6 +9,7 @@ from ldm.util import instantiate_from_config
|
|||
|
||||
from modules import shared, modelloader, devices
|
||||
from modules.paths import models_path
|
||||
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||
|
@ -20,7 +21,7 @@ checkpoints_loaded = collections.OrderedDict()
|
|||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
|
||||
from transformers import logging
|
||||
from transformers import logging, CLIPModel
|
||||
|
||||
logging.set_verbosity_error()
|
||||
except Exception:
|
||||
|
@ -154,6 +155,9 @@ def get_state_dict_from_checkpoint(pl_sd):
|
|||
return pl_sd
|
||||
|
||||
|
||||
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_info):
|
||||
checkpoint_file = checkpoint_info.filename
|
||||
sd_model_hash = checkpoint_info.hash
|
||||
|
@ -185,7 +189,7 @@ def load_model_weights(model, checkpoint_info):
|
|||
if os.path.exists(vae_file):
|
||||
print(f"Loading VAE weights from: {vae_file}")
|
||||
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
||||
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
||||
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
model.first_stage_model.load_state_dict(vae_dict)
|
||||
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
|
@ -203,14 +207,26 @@ def load_model_weights(model, checkpoint_info):
|
|||
model.sd_checkpoint_info = checkpoint_info
|
||||
|
||||
|
||||
def load_model():
|
||||
def load_model(checkpoint_info=None):
|
||||
from modules import lowvram, sd_hijack
|
||||
checkpoint_info = select_checkpoint()
|
||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||
|
||||
if checkpoint_info.config != shared.cmd_opts.config:
|
||||
print(f"Loading config from: {checkpoint_info.config}")
|
||||
|
||||
sd_config = OmegaConf.load(checkpoint_info.config)
|
||||
|
||||
if should_hijack_inpainting(checkpoint_info):
|
||||
# Hardcoded config for now...
|
||||
sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
|
||||
sd_config.model.params.use_ema = False
|
||||
sd_config.model.params.conditioning_key = "hybrid"
|
||||
sd_config.model.params.unet_config.params.in_channels = 9
|
||||
|
||||
# Create a "fake" config with a different name so that we know to unload it when switching models.
|
||||
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
|
||||
|
||||
do_inpainting_hijack()
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
load_model_weights(sd_model, checkpoint_info)
|
||||
|
||||
|
@ -234,9 +250,9 @@ def reload_model_weights(sd_model, info=None):
|
|||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||
return
|
||||
|
||||
if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
|
||||
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
|
||||
checkpoints_loaded.clear()
|
||||
shared.sd_model = load_model()
|
||||
shared.sd_model = load_model(checkpoint_info)
|
||||
return shared.sd_model
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
|
|
|
@ -117,6 +117,8 @@ class VanillaStableDiffusionSampler:
|
|||
self.config = None
|
||||
self.last_latent = None
|
||||
|
||||
self.conditioning_key = sd_model.model.conditioning_key
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return 0
|
||||
|
||||
|
@ -136,6 +138,12 @@ class VanillaStableDiffusionSampler:
|
|||
if self.stop_at is not None and self.step > self.stop_at:
|
||||
raise InterruptedException
|
||||
|
||||
# Have to unwrap the inpainting conditioning here to perform pre-processing
|
||||
image_conditioning = None
|
||||
if isinstance(cond, dict):
|
||||
image_conditioning = cond["c_concat"][0]
|
||||
cond = cond["c_crossattn"][0]
|
||||
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
||||
|
@ -157,6 +165,12 @@ class VanillaStableDiffusionSampler:
|
|||
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
||||
x_dec = img_orig * self.mask + self.nmask * x_dec
|
||||
|
||||
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
|
||||
# Note that they need to be lists because it just concatenates them later.
|
||||
if image_conditioning is not None:
|
||||
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
||||
|
||||
if self.mask is not None:
|
||||
|
@ -182,7 +196,7 @@ class VanillaStableDiffusionSampler:
|
|||
self.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = setup_img2img_steps(p, steps)
|
||||
|
||||
self.initialize(p)
|
||||
|
@ -196,20 +210,33 @@ class VanillaStableDiffusionSampler:
|
|||
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
||||
|
||||
self.init_latent = x
|
||||
self.last_latent = x
|
||||
self.step = 0
|
||||
|
||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||
if image_conditioning is not None:
|
||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
|
||||
samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
self.initialize(p)
|
||||
|
||||
self.init_latent = None
|
||||
self.last_latent = x
|
||||
self.step = 0
|
||||
|
||||
steps = steps or p.steps
|
||||
|
||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||
if image_conditioning is not None:
|
||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
# existing code fails with certain step counts, like 9
|
||||
try:
|
||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||
|
@ -228,7 +255,7 @@ class CFGDenoiser(torch.nn.Module):
|
|||
self.init_latent = None
|
||||
self.step = 0
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
|
||||
if state.interrupted or state.skipped:
|
||||
raise InterruptedException
|
||||
|
||||
|
@ -239,28 +266,29 @@ class CFGDenoiser(torch.nn.Module):
|
|||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1]:
|
||||
cond_in = torch.cat([tensor, uncond])
|
||||
|
||||
if shared.batch_cond_uncond:
|
||||
x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
|
||||
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = a + batch_size
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
|
||||
for batch_offset in range(0, tensor.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = min(a + batch_size, tensor.shape[0])
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b])
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
|
||||
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond)
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
|
||||
|
||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||
denoised = torch.clone(denoised_uncond)
|
||||
|
@ -306,6 +334,8 @@ class KDiffusionSampler:
|
|||
self.config = None
|
||||
self.last_latent = None
|
||||
|
||||
self.conditioning_key = sd_model.model.conditioning_key
|
||||
|
||||
def callback_state(self, d):
|
||||
step = d['i']
|
||||
latent = d["denoised"]
|
||||
|
@ -361,7 +391,7 @@ class KDiffusionSampler:
|
|||
|
||||
return extra_params_kwargs
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = setup_img2img_steps(p, steps)
|
||||
|
||||
if p.sampler_noise_scheduler_override:
|
||||
|
@ -388,12 +418,18 @@ class KDiffusionSampler:
|
|||
extra_params_kwargs['sigmas'] = sigma_sched
|
||||
|
||||
self.model_wrap_cfg.init_latent = x
|
||||
self.last_latent = x
|
||||
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale
|
||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
|
||||
steps = steps or p.steps
|
||||
|
||||
if p.sampler_noise_scheduler_override:
|
||||
|
@ -414,7 +450,13 @@ class KDiffusionSampler:
|
|||
else:
|
||||
extra_params_kwargs['sigmas'] = sigmas
|
||||
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
self.last_latent = x
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale
|
||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
return samples
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ import datetime
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
|
||||
import gradio as gr
|
||||
import tqdm
|
||||
|
@ -30,6 +31,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th
|
|||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||
parser.add_argument("--aesthetic_embeddings-dir", type=str, default=os.path.join(models_path, 'aesthetic_embeddings'), help="aesthetic_embeddings directory(default: aesthetic_embeddings)")
|
||||
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
||||
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||
|
@ -106,6 +108,21 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
|||
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
||||
loaded_hypernetwork = None
|
||||
|
||||
|
||||
os.makedirs(cmd_opts.aesthetic_embeddings_dir, exist_ok=True)
|
||||
aesthetic_embeddings = {}
|
||||
|
||||
|
||||
def update_aesthetic_embeddings():
|
||||
global aesthetic_embeddings
|
||||
aesthetic_embeddings = {f.replace(".pt", ""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
|
||||
os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
|
||||
aesthetic_embeddings = OrderedDict(**{"None": None}, **aesthetic_embeddings)
|
||||
|
||||
|
||||
update_aesthetic_embeddings()
|
||||
|
||||
|
||||
def reload_hypernetworks():
|
||||
global hypernetworks
|
||||
|
||||
|
@ -249,7 +266,7 @@ options_templates.update(options_section(('system', "System"), {
|
|||
}))
|
||||
|
||||
options_templates.update(options_section(('training', "Training"), {
|
||||
"unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
|
||||
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."),
|
||||
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
||||
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
||||
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||
|
@ -387,6 +404,11 @@ sd_upscalers = []
|
|||
|
||||
sd_model = None
|
||||
|
||||
clip_model = None
|
||||
|
||||
from modules.aesthetic_clip import AestheticCLIP
|
||||
aesthetic_clip = AestheticCLIP()
|
||||
|
||||
progress_print_out = sys.stdout
|
||||
|
||||
|
||||
|
|
|
@ -83,7 +83,7 @@ class PersonalizedBase(Dataset):
|
|||
|
||||
self.dataset.append(entry)
|
||||
|
||||
assert len(self.dataset) > 1, "No images have been found in the dataset."
|
||||
assert len(self.dataset) > 0, "No images have been found in the dataset."
|
||||
self.length = len(self.dataset) * repeats // batch_size
|
||||
|
||||
self.initial_indexes = np.arange(len(self.dataset))
|
||||
|
@ -91,7 +91,7 @@ class PersonalizedBase(Dataset):
|
|||
self.shuffle()
|
||||
|
||||
def shuffle(self):
|
||||
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
|
||||
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]
|
||||
|
||||
def create_text(self, filename_text):
|
||||
text = random.choice(self.lines)
|
||||
|
|
|
@ -5,6 +5,7 @@ import zlib
|
|||
from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
|
||||
from fonts.ttf import Roboto
|
||||
import torch
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
class EmbeddingEncoder(json.JSONEncoder):
|
||||
|
@ -133,7 +134,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
|
|||
from math import cos
|
||||
|
||||
image = srcimage.copy()
|
||||
|
||||
fontsize = 32
|
||||
if textfont is None:
|
||||
try:
|
||||
textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
|
||||
|
@ -150,7 +151,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
|
|||
image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
|
||||
|
||||
draw = ImageDraw.Draw(image)
|
||||
fontsize = 32
|
||||
|
||||
font = ImageFont.truetype(textfont, fontsize)
|
||||
padding = 10
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ if cmd_opts.deepdanbooru:
|
|||
import modules.deepbooru as deepbooru
|
||||
|
||||
|
||||
def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
|
||||
def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
|
||||
try:
|
||||
if process_caption:
|
||||
shared.interrogator.load()
|
||||
|
@ -22,7 +22,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
|
|||
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
|
||||
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
|
||||
|
||||
preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio)
|
||||
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio)
|
||||
|
||||
finally:
|
||||
|
||||
|
@ -34,7 +34,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
|
|||
|
||||
|
||||
|
||||
def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
|
||||
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
|
||||
width = process_width
|
||||
height = process_height
|
||||
src = os.path.abspath(process_src)
|
||||
|
@ -51,7 +51,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
|
|||
shared.state.textinfo = "Preprocessing..."
|
||||
shared.state.job_count = len(files)
|
||||
|
||||
def save_pic_with_caption(image, index):
|
||||
def save_pic_with_caption(image, index, existing_caption=None):
|
||||
caption = ""
|
||||
|
||||
if process_caption:
|
||||
|
@ -69,17 +69,26 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
|
|||
basename = f"{index:05}-{subindex[0]}-{filename_part}"
|
||||
image.save(os.path.join(dst, f"{basename}.png"))
|
||||
|
||||
if preprocess_txt_action == 'prepend' and existing_caption:
|
||||
caption = existing_caption + ' ' + caption
|
||||
elif preprocess_txt_action == 'append' and existing_caption:
|
||||
caption = caption + ' ' + existing_caption
|
||||
elif preprocess_txt_action == 'copy' and existing_caption:
|
||||
caption = existing_caption
|
||||
|
||||
caption = caption.strip()
|
||||
|
||||
if len(caption) > 0:
|
||||
with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
|
||||
file.write(caption)
|
||||
|
||||
subindex[0] += 1
|
||||
|
||||
def save_pic(image, index):
|
||||
save_pic_with_caption(image, index)
|
||||
def save_pic(image, index, existing_caption=None):
|
||||
save_pic_with_caption(image, index, existing_caption=existing_caption)
|
||||
|
||||
if process_flip:
|
||||
save_pic_with_caption(ImageOps.mirror(image), index)
|
||||
save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
|
||||
|
||||
def split_pic(image, inverse_xy):
|
||||
if inverse_xy:
|
||||
|
@ -112,6 +121,13 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
|
|||
except Exception:
|
||||
continue
|
||||
|
||||
existing_caption = None
|
||||
|
||||
try:
|
||||
existing_caption = open(os.path.splitext(filename)[0] + '.txt', 'r').read()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
|
||||
|
@ -124,9 +140,9 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
|
|||
|
||||
if process_split and ratio < 1.0 and ratio <= split_threshold:
|
||||
for splitted in split_pic(img, inverse_xy):
|
||||
save_pic(splitted, index)
|
||||
save_pic(splitted, index, existing_caption=existing_caption)
|
||||
else:
|
||||
img = images.resize_image(1, img, width, height)
|
||||
save_pic(img, index)
|
||||
save_pic(img, index, existing_caption=existing_caption)
|
||||
|
||||
shared.state.nextjob()
|
||||
|
|
|
@ -153,7 +153,7 @@ class EmbeddingDatabase:
|
|||
return None, None
|
||||
|
||||
|
||||
def create_embedding(name, num_vectors_per_token, init_text='*'):
|
||||
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
|
||||
cond_model = shared.sd_model.cond_stage_model
|
||||
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
|
||||
|
||||
|
@ -165,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
|
|||
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
||||
|
||||
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
|
||||
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||
if not overwrite_old:
|
||||
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||
|
||||
embedding = Embedding(vec, name)
|
||||
embedding.step = 0
|
||||
|
@ -275,6 +276,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
|
||||
epoch_num = embedding.step // len(ds)
|
||||
epoch_step = embedding.step - (epoch_num * len(ds)) + 1
|
||||
|
||||
|
|
|
@ -7,8 +7,8 @@ import modules.textual_inversion.preprocess
|
|||
from modules import sd_hijack, shared
|
||||
|
||||
|
||||
def create_embedding(name, initialization_text, nvpt):
|
||||
filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
|
||||
def create_embedding(name, initialization_text, nvpt, overwrite_old):
|
||||
filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
|
||||
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
import modules.scripts
|
||||
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
|
||||
StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.shared import opts, cmd_opts
|
||||
import modules.shared as shared
|
||||
import modules.processing as processing
|
||||
from modules.ui import plaintext_to_html
|
||||
|
||||
|
||||
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
|
||||
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args):
|
||||
p = StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
|
||||
|
@ -35,6 +36,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
|
|||
firstphase_height=firstphase_height if enable_hr else None,
|
||||
)
|
||||
|
||||
shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative)
|
||||
|
||||
if cmd_opts.enable_console_prompts:
|
||||
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
|
||||
|
||||
|
@ -53,4 +56,3 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
|
|||
processed.images = []
|
||||
|
||||
return processed.images, generation_info_js, plaintext_to_html(processed.info)
|
||||
|
||||
|
|
130
modules/ui.py
130
modules/ui.py
|
@ -25,7 +25,9 @@ import gradio.routes
|
|||
|
||||
from modules import sd_hijack, sd_models, localization
|
||||
from modules.paths import script_path
|
||||
from modules.shared import opts, cmd_opts, restricted_opts
|
||||
|
||||
from modules.shared import opts, cmd_opts, restricted_opts, aesthetic_embeddings
|
||||
|
||||
if cmd_opts.deepdanbooru:
|
||||
from modules.deepbooru import get_deepbooru_tags
|
||||
import modules.shared as shared
|
||||
|
@ -41,8 +43,11 @@ from modules import prompt_parser
|
|||
from modules.images import save_image
|
||||
import modules.textual_inversion.ui
|
||||
import modules.hypernetworks.ui
|
||||
|
||||
import modules.aesthetic_clip as aesthetic_clip
|
||||
import modules.images_history as img_his
|
||||
|
||||
|
||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
||||
mimetypes.init()
|
||||
mimetypes.add_type('application/javascript', '.js')
|
||||
|
@ -592,27 +597,29 @@ def apply_setting(key, value):
|
|||
return value
|
||||
|
||||
|
||||
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
|
||||
def refresh():
|
||||
refresh_method()
|
||||
args = refreshed_args() if callable(refreshed_args) else refreshed_args
|
||||
|
||||
for k, v in args.items():
|
||||
setattr(refresh_component, k, v)
|
||||
|
||||
return gr.update(**(args or {}))
|
||||
|
||||
refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
|
||||
refresh_button.click(
|
||||
fn=refresh,
|
||||
inputs=[],
|
||||
outputs=[refresh_component]
|
||||
)
|
||||
return refresh_button
|
||||
|
||||
|
||||
def create_ui(wrap_gradio_gpu_call):
|
||||
import modules.img2img
|
||||
import modules.txt2img
|
||||
|
||||
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
|
||||
def refresh():
|
||||
refresh_method()
|
||||
args = refreshed_args() if callable(refreshed_args) else refreshed_args
|
||||
|
||||
for k, v in args.items():
|
||||
setattr(refresh_component, k, v)
|
||||
|
||||
return gr.update(**(args or {}))
|
||||
|
||||
refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
|
||||
refresh_button.click(
|
||||
fn = refresh,
|
||||
inputs = [],
|
||||
outputs = [refresh_component]
|
||||
)
|
||||
return refresh_button
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
|
||||
|
@ -655,6 +662,8 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
|
||||
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
|
||||
|
||||
aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative = aesthetic_clip.create_ui()
|
||||
|
||||
with gr.Group():
|
||||
custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
|
||||
|
||||
|
@ -709,7 +718,16 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
denoising_strength,
|
||||
firstphase_width,
|
||||
firstphase_height,
|
||||
aesthetic_lr,
|
||||
aesthetic_weight,
|
||||
aesthetic_steps,
|
||||
aesthetic_imgs,
|
||||
aesthetic_slerp,
|
||||
aesthetic_imgs_text,
|
||||
aesthetic_slerp_angle,
|
||||
aesthetic_text_negative
|
||||
] + custom_inputs,
|
||||
|
||||
outputs=[
|
||||
txt2img_gallery,
|
||||
generation_info,
|
||||
|
@ -786,6 +804,14 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
|
||||
(firstphase_width, "First pass size-1"),
|
||||
(firstphase_height, "First pass size-2"),
|
||||
(aesthetic_lr, "Aesthetic LR"),
|
||||
(aesthetic_weight, "Aesthetic weight"),
|
||||
(aesthetic_steps, "Aesthetic steps"),
|
||||
(aesthetic_imgs, "Aesthetic embedding"),
|
||||
(aesthetic_slerp, "Aesthetic slerp"),
|
||||
(aesthetic_imgs_text, "Aesthetic text"),
|
||||
(aesthetic_text_negative, "Aesthetic text negative"),
|
||||
(aesthetic_slerp_angle, "Aesthetic slerp angle"),
|
||||
]
|
||||
|
||||
txt2img_preview_params = [
|
||||
|
@ -853,8 +879,8 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
|
||||
|
||||
with gr.Group():
|
||||
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
||||
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
||||
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width")
|
||||
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height")
|
||||
|
||||
with gr.Row():
|
||||
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
|
||||
|
@ -870,6 +896,8 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
|
||||
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
|
||||
|
||||
aesthetic_weight_im, aesthetic_steps_im, aesthetic_lr_im, aesthetic_slerp_im, aesthetic_imgs_im, aesthetic_imgs_text_im, aesthetic_slerp_angle_im, aesthetic_text_negative_im = aesthetic_clip.create_ui()
|
||||
|
||||
with gr.Group():
|
||||
custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
|
||||
|
||||
|
@ -960,6 +988,14 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
inpainting_mask_invert,
|
||||
img2img_batch_input_dir,
|
||||
img2img_batch_output_dir,
|
||||
aesthetic_lr_im,
|
||||
aesthetic_weight_im,
|
||||
aesthetic_steps_im,
|
||||
aesthetic_imgs_im,
|
||||
aesthetic_slerp_im,
|
||||
aesthetic_imgs_text_im,
|
||||
aesthetic_slerp_angle_im,
|
||||
aesthetic_text_negative_im,
|
||||
] + custom_inputs,
|
||||
outputs=[
|
||||
img2img_gallery,
|
||||
|
@ -1051,6 +1087,14 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
(seed_resize_from_w, "Seed resize from-1"),
|
||||
(seed_resize_from_h, "Seed resize from-2"),
|
||||
(denoising_strength, "Denoising strength"),
|
||||
(aesthetic_lr_im, "Aesthetic LR"),
|
||||
(aesthetic_weight_im, "Aesthetic weight"),
|
||||
(aesthetic_steps_im, "Aesthetic steps"),
|
||||
(aesthetic_imgs_im, "Aesthetic embedding"),
|
||||
(aesthetic_slerp_im, "Aesthetic slerp"),
|
||||
(aesthetic_imgs_text_im, "Aesthetic text"),
|
||||
(aesthetic_text_negative_im, "Aesthetic text negative"),
|
||||
(aesthetic_slerp_angle_im, "Aesthetic slerp angle"),
|
||||
]
|
||||
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
|
||||
|
||||
|
@ -1211,6 +1255,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
new_embedding_name = gr.Textbox(label="Name")
|
||||
initialization_text = gr.Textbox(label="Initialization text", value="*")
|
||||
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
|
||||
overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
|
@ -1219,11 +1264,25 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
with gr.Column():
|
||||
create_embedding = gr.Button(value="Create embedding", variant='primary')
|
||||
|
||||
with gr.Tab(label="Create aesthetic images embedding"):
|
||||
|
||||
new_embedding_name_ae = gr.Textbox(label="Name")
|
||||
process_src_ae = gr.Textbox(label='Source directory')
|
||||
batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
gr.HTML(value="")
|
||||
|
||||
with gr.Column():
|
||||
create_embedding_ae = gr.Button(value="Create images embedding", variant='primary')
|
||||
|
||||
with gr.Tab(label="Create hypernetwork"):
|
||||
new_hypernetwork_name = gr.Textbox(label="Name")
|
||||
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
|
||||
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
|
||||
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
|
||||
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
|
||||
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
|
@ -1237,6 +1296,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
process_dst = gr.Textbox(label='Destination directory')
|
||||
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
||||
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
||||
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
|
||||
|
||||
with gr.Row():
|
||||
process_flip = gr.Checkbox(label='Create flipped copies')
|
||||
|
@ -1262,14 +1322,17 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
)
|
||||
|
||||
with gr.Tab(label="Train"):
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
|
||||
with gr.Row():
|
||||
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
|
||||
with gr.Row():
|
||||
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
|
||||
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
|
||||
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
|
||||
with gr.Row():
|
||||
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
|
||||
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
|
||||
|
||||
batch_size = gr.Number(label='Batch size', value=1, precision=0)
|
||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
||||
|
@ -1303,6 +1366,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
new_embedding_name,
|
||||
initialization_text,
|
||||
nvpt,
|
||||
overwrite_old_embedding,
|
||||
],
|
||||
outputs=[
|
||||
train_embedding_name,
|
||||
|
@ -1311,13 +1375,30 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
]
|
||||
)
|
||||
|
||||
create_embedding_ae.click(
|
||||
fn=aesthetic_clip.generate_imgs_embd,
|
||||
inputs=[
|
||||
new_embedding_name_ae,
|
||||
process_src_ae,
|
||||
batch_ae
|
||||
],
|
||||
outputs=[
|
||||
aesthetic_imgs,
|
||||
aesthetic_imgs_im,
|
||||
ti_output,
|
||||
ti_outcome,
|
||||
]
|
||||
)
|
||||
|
||||
create_hypernetwork.click(
|
||||
fn=modules.hypernetworks.ui.create_hypernetwork,
|
||||
inputs=[
|
||||
new_hypernetwork_name,
|
||||
new_hypernetwork_sizes,
|
||||
overwrite_old_hypernetwork,
|
||||
new_hypernetwork_layer_structure,
|
||||
new_hypernetwork_add_layer_norm,
|
||||
new_hypernetwork_activation_func,
|
||||
],
|
||||
outputs=[
|
||||
train_hypernetwork_name,
|
||||
|
@ -1334,6 +1415,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
process_dst,
|
||||
process_width,
|
||||
process_height,
|
||||
preprocess_txt_action,
|
||||
process_flip,
|
||||
process_split,
|
||||
process_caption,
|
||||
|
@ -1352,7 +1434,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
_js="start_training_textual_inversion",
|
||||
inputs=[
|
||||
train_embedding_name,
|
||||
learn_rate,
|
||||
embedding_learn_rate,
|
||||
batch_size,
|
||||
dataset_directory,
|
||||
log_directory,
|
||||
|
@ -1377,7 +1459,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
_js="start_training_textual_inversion",
|
||||
inputs=[
|
||||
train_hypernetwork_name,
|
||||
learn_rate,
|
||||
hypernetwork_learn_rate,
|
||||
batch_size,
|
||||
dataset_directory,
|
||||
log_directory,
|
||||
|
|
|
@ -172,54 +172,54 @@ class Script(scripts.Script):
|
|||
if down > 0:
|
||||
down = target_h - init_img.height - up
|
||||
|
||||
init_image = p.init_images[0]
|
||||
|
||||
state.job_count = (1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0)
|
||||
|
||||
def expand(init, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
|
||||
def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
|
||||
is_horiz = is_left or is_right
|
||||
is_vert = is_top or is_bottom
|
||||
pixels_horiz = expand_pixels if is_horiz else 0
|
||||
pixels_vert = expand_pixels if is_vert else 0
|
||||
|
||||
res_w = init.width + pixels_horiz
|
||||
res_h = init.height + pixels_vert
|
||||
process_res_w = math.ceil(res_w / 64) * 64
|
||||
process_res_h = math.ceil(res_h / 64) * 64
|
||||
images_to_process = []
|
||||
output_images = []
|
||||
for n in range(count):
|
||||
res_w = init[n].width + pixels_horiz
|
||||
res_h = init[n].height + pixels_vert
|
||||
process_res_w = math.ceil(res_w / 64) * 64
|
||||
process_res_h = math.ceil(res_h / 64) * 64
|
||||
|
||||
img = Image.new("RGB", (process_res_w, process_res_h))
|
||||
img.paste(init, (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
|
||||
mask = Image.new("RGB", (process_res_w, process_res_h), "white")
|
||||
draw = ImageDraw.Draw(mask)
|
||||
draw.rectangle((
|
||||
expand_pixels + mask_blur if is_left else 0,
|
||||
expand_pixels + mask_blur if is_top else 0,
|
||||
mask.width - expand_pixels - mask_blur if is_right else res_w,
|
||||
mask.height - expand_pixels - mask_blur if is_bottom else res_h,
|
||||
), fill="black")
|
||||
img = Image.new("RGB", (process_res_w, process_res_h))
|
||||
img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
|
||||
mask = Image.new("RGB", (process_res_w, process_res_h), "white")
|
||||
draw = ImageDraw.Draw(mask)
|
||||
draw.rectangle((
|
||||
expand_pixels + mask_blur if is_left else 0,
|
||||
expand_pixels + mask_blur if is_top else 0,
|
||||
mask.width - expand_pixels - mask_blur if is_right else res_w,
|
||||
mask.height - expand_pixels - mask_blur if is_bottom else res_h,
|
||||
), fill="black")
|
||||
|
||||
np_image = (np.asarray(img) / 255.0).astype(np.float64)
|
||||
np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
|
||||
noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
|
||||
out = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
|
||||
np_image = (np.asarray(img) / 255.0).astype(np.float64)
|
||||
np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
|
||||
noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
|
||||
output_images.append(Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB"))
|
||||
|
||||
target_width = min(process_width, init.width + pixels_horiz) if is_horiz else img.width
|
||||
target_height = min(process_height, init.height + pixels_vert) if is_vert else img.height
|
||||
target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width
|
||||
target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height
|
||||
p.width = target_width if is_horiz else img.width
|
||||
p.height = target_height if is_vert else img.height
|
||||
|
||||
crop_region = (
|
||||
0 if is_left else out.width - target_width,
|
||||
0 if is_top else out.height - target_height,
|
||||
target_width if is_left else out.width,
|
||||
target_height if is_top else out.height,
|
||||
)
|
||||
crop_region = (
|
||||
0 if is_left else output_images[n].width - target_width,
|
||||
0 if is_top else output_images[n].height - target_height,
|
||||
target_width if is_left else output_images[n].width,
|
||||
target_height if is_top else output_images[n].height,
|
||||
)
|
||||
mask = mask.crop(crop_region)
|
||||
p.image_mask = mask
|
||||
|
||||
image_to_process = out.crop(crop_region)
|
||||
mask = mask.crop(crop_region)
|
||||
image_to_process = output_images[n].crop(crop_region)
|
||||
images_to_process.append(image_to_process)
|
||||
|
||||
p.width = target_width if is_horiz else img.width
|
||||
p.height = target_height if is_vert else img.height
|
||||
p.init_images = [image_to_process]
|
||||
p.image_mask = mask
|
||||
p.init_images = images_to_process
|
||||
|
||||
latent_mask = Image.new("RGB", (p.width, p.height), "white")
|
||||
draw = ImageDraw.Draw(latent_mask)
|
||||
|
@ -232,31 +232,52 @@ class Script(scripts.Script):
|
|||
p.latent_mask = latent_mask
|
||||
|
||||
proc = process_images(p)
|
||||
proc_img = proc.images[0]
|
||||
|
||||
if initial_seed_and_info[0] is None:
|
||||
initial_seed_and_info[0] = proc.seed
|
||||
initial_seed_and_info[1] = proc.info
|
||||
|
||||
out.paste(proc_img, (0 if is_left else out.width - proc_img.width, 0 if is_top else out.height - proc_img.height))
|
||||
out = out.crop((0, 0, res_w, res_h))
|
||||
return out
|
||||
for n in range(count):
|
||||
output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height))
|
||||
output_images[n] = output_images[n].crop((0, 0, res_w, res_h))
|
||||
|
||||
img = init_image
|
||||
return output_images
|
||||
|
||||
if left > 0:
|
||||
img = expand(img, left, is_left=True)
|
||||
if right > 0:
|
||||
img = expand(img, right, is_right=True)
|
||||
if up > 0:
|
||||
img = expand(img, up, is_top=True)
|
||||
if down > 0:
|
||||
img = expand(img, down, is_bottom=True)
|
||||
batch_count = p.n_iter
|
||||
batch_size = p.batch_size
|
||||
p.n_iter = 1
|
||||
state.job_count = batch_count * ((1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0))
|
||||
all_processed_images = []
|
||||
|
||||
res = Processed(p, [img], initial_seed_and_info[0], initial_seed_and_info[1])
|
||||
for i in range(batch_count):
|
||||
imgs = [init_img] * batch_size
|
||||
state.job = f"Batch {i + 1} out of {batch_count}"
|
||||
|
||||
if left > 0:
|
||||
imgs = expand(imgs, batch_size, left, is_left=True)
|
||||
if right > 0:
|
||||
imgs = expand(imgs, batch_size, right, is_right=True)
|
||||
if up > 0:
|
||||
imgs = expand(imgs, batch_size, up, is_top=True)
|
||||
if down > 0:
|
||||
imgs = expand(imgs, batch_size, down, is_bottom=True)
|
||||
|
||||
all_processed_images += imgs
|
||||
|
||||
all_images = all_processed_images
|
||||
|
||||
combined_grid_image = images.image_grid(all_processed_images)
|
||||
unwanted_grid_because_of_img_count = len(all_processed_images) < 2 and opts.grid_only_if_multiple
|
||||
if opts.return_grid and not unwanted_grid_because_of_img_count:
|
||||
all_images = [combined_grid_image] + all_processed_images
|
||||
|
||||
res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1])
|
||||
|
||||
if opts.samples_save:
|
||||
images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
|
||||
for img in all_processed_images:
|
||||
images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
|
||||
|
||||
if opts.grid_save and not unwanted_grid_because_of_img_count:
|
||||
images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.grid_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
|
||||
|
||||
return res
|
||||
|
||||
|
|
|
@ -89,6 +89,7 @@ def apply_checkpoint(p, x, xs):
|
|||
if info is None:
|
||||
raise RuntimeError(f"Unknown checkpoint: {x}")
|
||||
modules.sd_models.reload_model_weights(shared.sd_model, info)
|
||||
p.sd_model = shared.sd_model
|
||||
|
||||
|
||||
def confirm_checkpoints(p, xs):
|
||||
|
|
|
@ -477,7 +477,7 @@ input[type="range"]{
|
|||
padding: 0;
|
||||
}
|
||||
|
||||
#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
|
||||
#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization, #refresh_aesthetic_embeddings{
|
||||
max-width: 2.5em;
|
||||
min-width: 2.5em;
|
||||
height: 2.4em;
|
||||
|
|
5
webui.py
5
webui.py
|
@ -118,7 +118,8 @@ def api_only():
|
|||
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
|
||||
|
||||
|
||||
def webui(launch_api=False):
|
||||
def webui():
|
||||
launch_api = cmd_opts.api
|
||||
initialize()
|
||||
|
||||
while 1:
|
||||
|
@ -158,4 +159,4 @@ if __name__ == "__main__":
|
|||
if cmd_opts.nowebui:
|
||||
api_only()
|
||||
else:
|
||||
webui(cmd_opts.api)
|
||||
webui()
|
||||
|
|
Loading…
Reference in a new issue