Merge branch 'AUTOMATIC1111:master' into master

commit d41ac174e2
Author: 不会画画的中医不是好程序员, 2022-10-16 10:04:05 +08:00, committed by GitHub
26 changed files with 407 additions and 171 deletions

.github/workflows/on_pull_request.yaml (new file, 36 lines)

@@ -0,0 +1,36 @@
+# See https://github.com/actions/starter-workflows/blob/1067f16ad8a1eac328834e4b0ae24f7d206f810d/ci/pylint.yml for original reference file
+name: Run Linting/Formatting on Pull Requests
+
+on:
+  - push
+  - pull_request
+  # See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#onpull_requestpull_request_targetbranchesbranches-ignore for syntax docs
+  # if you want to filter out branches, delete the `- pull_request` and uncomment these lines :
+  # pull_request:
+  #   branches:
+  #     - master
+  #   branches-ignore:
+  #     - development
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout Code
+        uses: actions/checkout@v3
+      - name: Set up Python 3.10
+        uses: actions/setup-python@v3
+        with:
+          python-version: 3.10.6
+      - name: Install PyLint
+        run: |
+          python -m pip install --upgrade pip
+          pip install pylint
+      # This lets PyLint check to see if it can resolve imports
+      - name: Install dependencies
+        run: |
+          export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
+          python launch.py
+      - name: Analysing the code with pylint
+        run: |
+          pylint $(git ls-files '*.py')

.pylintrc (new file, 3 lines)

@@ -0,0 +1,3 @@
+# See https://pylint.pycqa.org/en/latest/user_guide/messages/message_control.html
+[MESSAGES CONTROL]
+disable=C,R,W,E,I

javascript/dragdrop.js

@@ -43,7 +43,7 @@ function dropReplaceImage( imgWrap, files ) {
 window.document.addEventListener('dragover', e => {
     const target = e.composedPath()[0];
     const imgWrap = target.closest('[data-testid="image"]');
-    if ( !imgWrap && target.placeholder != "Prompt") {
+    if ( !imgWrap && target.placeholder.indexOf("Prompt") == -1) {
         return;
     }
     e.stopPropagation();
@@ -53,7 +53,7 @@ window.document.addEventListener('dragover', e => {
 window.document.addEventListener('drop', e => {
     const target = e.composedPath()[0];
-    if (target.placeholder === "Prompt") {
+    if (target.placeholder.indexOf("Prompt") == -1) {
         return;
     }
     const imgWrap = target.closest('[data-testid="image"]');

javascript/edit-attention.js

@@ -2,6 +2,8 @@ addEventListener('keydown', (event) => {
 	let target = event.originalTarget || event.composedPath()[0];
 	if (!target.hasAttribute("placeholder")) return;
 	if (!target.placeholder.toLowerCase().includes("prompt")) return;
+
+	if (! (event.metaKey || event.ctrlKey)) return;
 	let plus = "ArrowUp"
 	let minus = "ArrowDown"

javascript/hints.js

@@ -16,6 +16,8 @@ titles = {
     "\u{1f3a8}": "Add a random artist to the prompt.",
     "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
     "\u{1f4c2}": "Open images output directory",
+    "\u{1f4be}": "Save style",
+    "\u{1f4cb}": "Apply selected styles to current prompt",
 
     "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
     "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",

javascript/imageParams.js

@@ -2,7 +2,7 @@ window.onload = (function(){
     window.addEventListener('drop', e => {
         const target = e.composedPath()[0];
         const idx = selected_gallery_index();
-        if (target.placeholder != "Prompt") return;
+        if (target.placeholder.indexOf("Prompt") == -1) return;
 
         let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";

javascript/progressbar.js

@@ -1,5 +1,7 @@
 // code related to showing and updating progressbar shown as the image is being made
 global_progressbars = {}
+galleries = {}
+galleryObservers = {}
 
 function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
     var progressbar = gradioApp().getElementById(id_progressbar)
@@ -31,13 +33,24 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
                 preview.style.width = gallery.clientWidth + "px"
                 preview.style.height = gallery.clientHeight + "px"
 
+                //only watch gallery if there is a generation process going on
+                check_gallery(id_gallery);
+
                 var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
                 if(!progressDiv){
                     if (skip) {
                         skip.style.display = "none"
                     }
                     interrupt.style.display = "none"
+
+                    //disconnect observer once generation finished, so user can close selected image if they want
+                    if (galleryObservers[id_gallery]) {
+                        galleryObservers[id_gallery].disconnect();
+                        galleries[id_gallery] = null;
+                    }
                 }
             }
             window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500)
@@ -46,6 +59,28 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
     }
 }
 
+function check_gallery(id_gallery){
+    let gallery = gradioApp().getElementById(id_gallery)
+    // if gallery has no change, no need to setting up observer again.
+    if (gallery && galleries[id_gallery] !== gallery){
+        galleries[id_gallery] = gallery;
+        if(galleryObservers[id_gallery]){
+            galleryObservers[id_gallery].disconnect();
+        }
+        let prevSelectedIndex = selected_gallery_index();
+        galleryObservers[id_gallery] = new MutationObserver(function (){
+            let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
+            let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
+            if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
+                //automatically re-open previously selected index (if exists)
+                galleryButtons[prevSelectedIndex].click();
+                showGalleryImage();
+            }
+        })
+        galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false })
+    }
+}
+
 onUiUpdate(function(){
     check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
     check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')

javascript/ui.js

@@ -141,7 +141,7 @@ function submit_img2img(){
 
 function ask_for_style_name(_, prompt_text, negative_prompt_text) {
     name_ = prompt('Style name:')
-    return name_ === null ? [null, null, null]: [name_, prompt_text, negative_prompt_text]
+    return [name_, prompt_text, negative_prompt_text]
 }
@@ -187,12 +187,10 @@ onUiUpdate(function(){
     if (!txt2img_textarea) {
         txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
         txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
-        txt2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "txt2img_generate"));
     }
     if (!img2img_textarea) {
         img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
         img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
-        img2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "img2img_generate"));
     }
 })
@@ -220,14 +218,6 @@ function update_token_counter(button_id) {
     token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
 }
 
-function submit_prompt(event, generate_button_id) {
-    if (event.altKey && event.keyCode === 13) {
-        event.preventDefault();
-        gradioApp().getElementById(generate_button_id).click();
-        return;
-    }
-}
-
 function restart_reload(){
     document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
     setTimeout(function(){location.reload()},2000)

launch.py

@@ -9,6 +9,7 @@ import platform
 dir_repos = "repositories"
 
 python = sys.executable
 git = os.environ.get('GIT', "git")
+index_url = os.environ.get('INDEX_URL', "")
 
 
 def extract_arg(args, name):
@@ -57,7 +58,8 @@ def run_python(code, desc=None, errdesc=None):
 
 def run_pip(args, desc=None):
-    return run(f'"{python}" -m pip {args} --prefer-binary', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
+    index_url_line = f' --index-url {index_url}' if index_url != '' else ''
+    return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
 
 
 def check_run_python(code):
@@ -102,6 +104,7 @@ def prepare_enviroment():
     args = shlex.split(commandline_args)
 
     args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test')
+    args, reinstall_xformers = extract_arg(args, '--reinstall-xformers')
     xformers = '--xformers' in args
     deepdanbooru = '--deepdanbooru' in args
     ngrok = '--ngrok' in args
@@ -126,9 +129,9 @@ def prepare_enviroment():
     if not is_installed("clip"):
         run_pip(f"install {clip_package}", "clip")
 
-    if not is_installed("xformers") and xformers and platform.python_version().startswith("3.10"):
+    if (not is_installed("xformers") or reinstall_xformers) and xformers and platform.python_version().startswith("3.10"):
         if platform.system() == "Windows":
-            run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/c/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers")
+            run_pip("install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers")
         elif platform.system() == "Linux":
             run_pip("install xformers", "xformers")

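Note: the body of extract_arg is not part of this diff; its call sites above imply simple semantics (remove the flag from the argument list and report whether it was present). A minimal sketch under that assumption:

    # hypothetical implementation, consistent with the calls shown above
    def extract_arg(args, name):
        return [x for x in args if x != name], name in args

    args, reinstall_xformers = extract_arg(['--xformers', '--reinstall-xformers'], '--reinstall-xformers')
    print(args)                # ['--xformers']
    print(reinstall_xformers)  # True

With this, putting --reinstall-xformers in COMMANDLINE_ARGS forces the pip install branch even when xformers is already installed, and setting INDEX_URL routes every run_pip call through an alternate package index.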
modules/deepbooru.py

@@ -102,7 +102,7 @@ def get_deepbooru_tags_model():
         tags = dd.project.load_tags_from_project(model_path)
         model = dd.project.load_model_from_project(
-            model_path, compile_model=True
+            model_path, compile_model=False
         )
         return model, tags

modules/hypernetworks/hypernetwork.py

@@ -182,7 +182,21 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
     return self.to_out(out)
 
 
-def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def stack_conds(conds):
+    if len(conds) == 1:
+        return torch.stack(conds)
+
+    # same as in reconstruct_multicond_batch
+    token_count = max([x.shape[0] for x in conds])
+    for i in range(len(conds)):
+        if conds[i].shape[0] != token_count:
+            last_vector = conds[i][-1:]
+            last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
+            conds[i] = torch.vstack([conds[i], last_vector_repeated])
+
+    return torch.stack(conds)
+
+
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
     assert hypernetwork_name, 'hypernetwork not selected'
 
     path = shared.hypernetworks.get(hypernetwork_name, None)
@@ -211,7 +225,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
-        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
+        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
 
     if unload:
         shared.sd_model.cond_stage_model.to(devices.cpu)
@@ -235,7 +249,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
     optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
 
     pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
-    for i, entry in pbar:
+    for i, entries in pbar:
         hypernetwork.step = i + ititial_step
 
         scheduler.apply(optimizer, hypernetwork.step)
@@ -246,26 +260,29 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
             break
 
         with torch.autocast("cuda"):
-            cond = entry.cond.to(devices.device)
-            x = entry.latent.to(devices.device)
-            loss = shared.sd_model(x.unsqueeze(0), cond)[0]
+            c = stack_conds([entry.cond for entry in entries]).to(devices.device)
+            # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
+            x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+            loss = shared.sd_model(x, c)[0]
             del x
-            del cond
+            del c
 
         losses[hypernetwork.step % losses.shape[0]] = loss.item()
 
         optimizer.zero_grad()
         loss.backward()
        optimizer.step()
-        pbar.set_description(f"loss: {losses.mean():.7f}")
+        mean_loss = losses.mean()
+        if torch.isnan(mean_loss):
+            raise RuntimeError("Loss diverged.")
+        pbar.set_description(f"loss: {mean_loss:.7f}")
 
         if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
             last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
             hypernetwork.save(last_saved_file)
 
             textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
-                "loss": f"{losses.mean():.7f}",
+                "loss": f"{mean_loss:.7f}",
                 "learn_rate": scheduler.learn_rate
             })
@@ -292,7 +309,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
             p.width = preview_width
             p.height = preview_height
         else:
-            p.prompt = entry.cond_text
+            p.prompt = entries[0].cond_text
             p.steps = 20
 
         preview_text = p.prompt
@@ -313,9 +330,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
     shared.state.textinfo = f"""
 <p>
-Loss: {losses.mean():.7f}<br/>
+Loss: {mean_loss:.7f}<br/>
 Step: {hypernetwork.step}<br/>
-Last prompt: {html.escape(entry.cond_text)}<br/>
+Last prompt: {html.escape(entries[0].cond_text)}<br/>
 Last saved embedding: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>

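The stack_conds helper added above pads conditioning tensors of unequal token length before batching. A rough illustration with hypothetical shapes (77 and 154 tokens, 768-dim vectors): the shorter tensor's last vector is repeated until lengths match, then everything stacks into one batch:

    import torch

    conds = [torch.ones(77, 768), torch.ones(154, 768)]
    token_count = max(x.shape[0] for x in conds)  # 154
    for i in range(len(conds)):
        if conds[i].shape[0] != token_count:
            # repeat the final token vector to fill the gap
            pad = conds[i][-1:].repeat([token_count - conds[i].shape[0], 1])
            conds[i] = torch.vstack([conds[i], pad])
    print(torch.stack(conds).shape)  # torch.Size([2, 154, 768])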
modules/images_history.py

@@ -97,14 +97,16 @@ def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, ima
 def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
-    if tabname == "txt2img":
+    if opts.outdir_samples != "":
+        dir_name = opts.outdir_samples
+    elif tabname == "txt2img":
         dir_name = opts.outdir_txt2img_samples
     elif tabname == "img2img":
         dir_name = opts.outdir_img2img_samples
     elif tabname == "extras":
         dir_name = opts.outdir_extras_samples
     d = dir_name.split("/")
-    dir_name = d[0]
+    dir_name = "/" if dir_name.startswith("/") else d[0]
     for p in d[1:]:
         dir_name = os.path.join(dir_name, p)
     with gr.Row():

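The dir_name change above fixes output directories given as absolute paths. A small sketch of the difference, with a hypothetical path:

    import os

    dir_name = "/outputs/txt2img-images"
    d = dir_name.split("/")

    old = d[0]                                       # "" drops the leading slash
    new = "/" if dir_name.startswith("/") else d[0]
    for p in d[1:]:
        old = os.path.join(old, p)
        new = os.path.join(new, p)

    print(old)  # outputs/txt2img-images  (silently relative)
    print(new)  # /outputs/txt2img-images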
modules/processing.py

@@ -140,7 +140,7 @@ class Processed:
         self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
         self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
         self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
-        self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
+        self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
         self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
 
         self.all_prompts = all_prompts or [self.prompt]
@@ -528,7 +528,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             firstphase_height_truncated = int(scale * self.height)
 
         else:
-            self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
 
             width_ratio = self.width / self.firstphase_width
             height_ratio = self.height / self.firstphase_height
@@ -540,6 +539,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
                 firstphase_width_truncated = self.firstphase_height * self.width / self.height
                 firstphase_height_truncated = self.firstphase_height
 
+        self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
         self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
         self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
@@ -557,11 +557,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
 
-        decoded_samples = decode_first_stage(self.sd_model, samples)
-
-        if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
-            decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
+        if opts.use_scale_latent_for_hires_fix:
+            samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
         else:
+            decoded_samples = decode_first_stage(self.sd_model, samples)
             lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
 
             batch_images = []
@@ -578,7 +578,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
-        decoded_samples = decoded_samples.to(shared.device)
-        decoded_samples = 2. * decoded_samples - 1.
+            decoded_samples = decoded_samples.to(shared.device)
+            decoded_samples = 2. * decoded_samples - 1.
 
-        samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
+            samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
 
         shared.state.nextjob()

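The new use_scale_latent_for_hires_fix branch above upscales in latent space instead of decoding through the VAE first. A hedged sketch with dummy tensors, assuming opt_f is the 8x VAE downsampling factor used in this codebase:

    import torch
    import torch.nn.functional as F

    opt_f = 8
    samples = torch.randn(2, 4, 512 // opt_f, 512 // opt_f)  # low-res latents

    # latent-space path: one interpolate call, no VAE round-trip
    upscaled = F.interpolate(samples, size=(768 // opt_f, 768 // opt_f), mode="bilinear")
    print(upscaled.shape)  # torch.Size([2, 4, 96, 96])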
modules/sd_hijack.py

@@ -24,7 +24,7 @@ def apply_optimizations():
 
     ldm.modules.diffusionmodules.model.nonlinearity = silu
 
-    if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)):
+    if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
         print("Applying xformers cross attention optimization.")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
         ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward

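The (8, 6) to (9, 0) change above widens the xformers gate via ordinary Python tuple comparison; for example, an Ada-generation GPU reports compute capability (8, 9), which the old ceiling excluded:

    print((6, 0) <= (8, 9) <= (8, 6))  # False: rejected before
    print((6, 0) <= (8, 9) <= (9, 0))  # True: accepted now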
modules/sd_models.py

@@ -1,4 +1,4 @@
-import glob
+import collections
 import os.path
 import sys
 from collections import namedtuple
@@ -15,6 +15,7 @@ model_path = os.path.abspath(os.path.join(models_path, model_dir))
 CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
 checkpoints_list = {}
+checkpoints_loaded = collections.OrderedDict()
 
 try:
     # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
@@ -132,41 +133,45 @@ def load_model_weights(model, checkpoint_info):
     checkpoint_file = checkpoint_info.filename
     sd_model_hash = checkpoint_info.hash
 
-    print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
-
-    pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
-
-    if "global_step" in pl_sd:
-        print(f"Global Step: {pl_sd['global_step']}")
-
-    sd = get_state_dict_from_checkpoint(pl_sd)
-
-    model.load_state_dict(sd, strict=False)
-
-    if shared.cmd_opts.opt_channelslast:
-        model.to(memory_format=torch.channels_last)
-
-    if not shared.cmd_opts.no_half:
-        model.half()
-
-    devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
-    devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
-
-    vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
-
-    if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
-        vae_file = shared.cmd_opts.vae_path
-
-    if os.path.exists(vae_file):
-        print(f"Loading VAE weights from: {vae_file}")
-
-        vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
-
-        vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
-
-        model.first_stage_model.load_state_dict(vae_dict)
-
-    model.first_stage_model.to(devices.dtype_vae)
+    if checkpoint_info not in checkpoints_loaded:
+        print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
+
+        pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
+        if "global_step" in pl_sd:
+            print(f"Global Step: {pl_sd['global_step']}")
+
+        sd = get_state_dict_from_checkpoint(pl_sd)
+        model.load_state_dict(sd, strict=False)
+
+        if shared.cmd_opts.opt_channelslast:
+            model.to(memory_format=torch.channels_last)
+
+        if not shared.cmd_opts.no_half:
+            model.half()
+
+        devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
+        devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
+
+        vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
+
+        if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
+            vae_file = shared.cmd_opts.vae_path
+
+        if os.path.exists(vae_file):
+            print(f"Loading VAE weights from: {vae_file}")
+            vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
+            vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
+            model.first_stage_model.load_state_dict(vae_dict)
+
+        model.first_stage_model.to(devices.dtype_vae)
+
+        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+        while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+            checkpoints_loaded.popitem(last=False)  # LRU
+    else:
+        print(f"Loading weights [{sd_model_hash}] from cache")
+        checkpoints_loaded.move_to_end(checkpoint_info)
+        model.load_state_dict(checkpoints_loaded[checkpoint_info])
 
     model.sd_model_hash = sd_model_hash
     model.sd_model_checkpoint = checkpoint_file
@@ -205,6 +210,7 @@ def reload_model_weights(sd_model, info=None):
         return
 
     if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
+        checkpoints_loaded.clear()
         shared.sd_model = load_model()
         return shared.sd_model

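The caching above is a plain OrderedDict LRU. A minimal standalone sketch of the same pattern, with stand-in keys and capacity 2:

    import collections

    cache = collections.OrderedDict()
    cache["ckpt_a"] = "state_a"
    cache["ckpt_b"] = "state_b"

    cache.move_to_end("ckpt_a")    # cache hit: mark as most recently used
    cache["ckpt_c"] = "state_c"
    while len(cache) > 2:
        cache.popitem(last=False)  # evict the least recently used entry

    print(list(cache))  # ['ckpt_a', 'ckpt_c']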
modules/shared.py

@@ -218,6 +218,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
     "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
     "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
     "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
+    "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"),
 }))
 
 options_templates.update(options_section(('face-restoration', "Face restoration"), {
@@ -242,6 +243,7 @@ options_templates.update(options_section(('training', "Training"), {
 options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
+    "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
     "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
     "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
     "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
@@ -255,7 +257,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
     'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
     "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
-    'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
 }))
 
 options_templates.update(options_section(('interrogate', "Interrogate Options"), {
@@ -283,6 +284,7 @@ options_templates.update(options_section(('ui', "User interface"), {
     "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
     "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
     "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
+    'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
 }))
 
 options_templates.update(options_section(('sampler-params', "Sampler parameters"), {

modules/textual_inversion/dataset.py

@@ -24,11 +24,12 @@ class DatasetEntry:
 
 class PersonalizedBase(Dataset):
-    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
-        re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None
+    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
+        re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
 
         self.placeholder_token = placeholder_token
 
+        self.batch_size = batch_size
         self.width = width
         self.height = height
         self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@@ -78,13 +79,14 @@ class PersonalizedBase(Dataset):
             if include_cond:
                 entry.cond_text = self.create_text(filename_text)
-                entry.cond = cond_model([entry.cond_text]).to(devices.cpu)
+                entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
 
             self.dataset.append(entry)
 
-        self.length = len(self.dataset) * repeats
+        assert len(self.dataset) > 1, "No images have been found in the dataset."
+        self.length = len(self.dataset) * repeats // batch_size
 
-        self.initial_indexes = np.arange(self.length) % len(self.dataset)
+        self.initial_indexes = np.arange(len(self.dataset))
         self.indexes = None
         self.shuffle()
@@ -101,13 +103,19 @@ class PersonalizedBase(Dataset):
         return self.length
 
     def __getitem__(self, i):
-        if i % len(self.dataset) == 0:
-            self.shuffle()
-
-        index = self.indexes[i % len(self.indexes)]
-        entry = self.dataset[index]
-
-        if entry.cond is None:
-            entry.cond_text = self.create_text(entry.filename_text)
-
-        return entry
+        res = []
+
+        for j in range(self.batch_size):
+            position = i * self.batch_size + j
+            if position % len(self.indexes) == 0:
+                self.shuffle()
+
+            index = self.indexes[position % len(self.indexes)]
+            entry = self.dataset[index]
+
+            if entry.cond is None:
+                entry.cond_text = self.create_text(entry.filename_text)
+
+            res.append(entry)
+
+        return res

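With the batching change above, item i of the dataset now covers batch_size consecutive positions in the shuffled index, which is also why self.length is divided by batch_size. The arithmetic, isolated:

    batch_size = 3
    for i in range(2):
        print(i, [i * batch_size + j for j in range(batch_size)])
    # 0 [0, 1, 2]
    # 1 [3, 4, 5]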
modules/textual_inversion/textual_inversion.py

@@ -88,9 +88,9 @@ class EmbeddingDatabase:
             data = []
 
-            if filename.upper().endswith('.PNG'):
+            if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
                 embed_image = Image.open(path)
-                if 'sd-ti-embedding' in embed_image.text:
+                if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
                     data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
                     name = data.get('name', name)
                 else:
@@ -199,7 +199,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
     })
 
 
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
     assert embedding_name, 'embedding not selected'
 
     shared.state.textinfo = "Initializing textual inversion training..."
@@ -231,7 +231,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
-        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
 
     hijack = sd_hijack.model_hijack
@@ -242,6 +242,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
     last_saved_file = "<none>"
     last_saved_image = "<none>"
+    embedding_yet_to_be_embedded = False
 
     ititial_step = embedding.step or 0
     if ititial_step > steps:
@@ -251,7 +252,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
     optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
 
     pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
-    for i, entry in pbar:
+    for i, entries in pbar:
         embedding.step = i + ititial_step
 
         scheduler.apply(optimizer, embedding.step)
@@ -262,10 +263,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
             break
 
         with torch.autocast("cuda"):
-            c = cond_model([entry.cond_text])
-
-            x = entry.latent.to(devices.device)
-            loss = shared.sd_model(x.unsqueeze(0), c)[0]
+            c = cond_model([entry.cond_text for entry in entries])
+            x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+            loss = shared.sd_model(x, c)[0]
             del x
 
         losses[embedding.step % losses.shape[0]] = loss.item()
@@ -282,6 +282,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
         if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
             last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
             embedding.save(last_saved_file)
+            embedding_yet_to_be_embedded = True
 
             write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
                 "loss": f"{losses.mean():.7f}",
@@ -307,7 +308,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
                 p.width = preview_width
                 p.height = preview_height
             else:
-                p.prompt = entry.cond_text
+                p.prompt = entries[0].cond_text
                 p.steps = 20
                 p.width = training_width
                 p.height = training_height
@@ -319,7 +320,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
             shared.state.current_image = image
 
-            if save_image_with_stored_embedding and os.path.exists(last_saved_file):
+            if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
 
                 last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
@@ -328,15 +329,22 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
                 info.add_text("sd-ti-embedding", embedding_to_b64(data))
 
                 title = "<{}>".format(data.get('name', '???'))
+
+                try:
+                    vectorSize = list(data['string_to_param'].values())[0].shape[0]
+                except Exception as e:
+                    vectorSize = '?'
+
                 checkpoint = sd_models.select_checkpoint()
                 footer_left = checkpoint.model_name
                 footer_mid = '[{}]'.format(checkpoint.hash)
-                footer_right = '{}'.format(embedding.step)
+                footer_right = '{}v {}s'.format(vectorSize, embedding.step)
 
                 captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
                 captioned_image = insert_image_data_embed(captioned_image, data)
 
                 captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+
+                embedding_yet_to_be_embedded = False
 
             image.save(last_saved_image)
@@ -348,7 +356,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
 <p>
 Loss: {losses.mean():.7f}<br/>
 Step: {embedding.step}<br/>
-Last prompt: {html.escape(entry.cond_text)}<br/>
+Last prompt: {html.escape(entries[0].cond_text)}<br/>
 Last saved embedding: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>

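The vectorSize footer added above reads the vector count out of the embedding payload; the try/except covers files without that key. A sketch with a hypothetical 3-vector embedding:

    import torch

    data = {"string_to_param": {"*": torch.zeros(3, 768)}}
    vectorSize = list(data["string_to_param"].values())[0].shape[0]
    print("{}v {}s".format(vectorSize, 1000))  # "3v 1000s", as drawn in the image footer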
View file

@ -7,6 +7,7 @@ import mimetypes
import os import os
import random import random
import sys import sys
import tempfile
import time import time
import traceback import traceback
import platform import platform
@ -80,6 +81,8 @@ art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙ paste_symbol = '\u2199\ufe0f' # ↙
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄 refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
apply_style_symbol = '\U0001f4cb' # 📋
def plaintext_to_html(text): def plaintext_to_html(text):
@ -88,6 +91,14 @@ def plaintext_to_html(text):
def image_from_url_text(filedata): def image_from_url_text(filedata):
if type(filedata) == dict and filedata["is_file"]:
filename = filedata["name"]
tempdir = os.path.normpath(tempfile.gettempdir())
normfn = os.path.normpath(filename)
assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
return Image.open(filename)
if type(filedata) == list: if type(filedata) == list:
if len(filedata) == 0: if len(filedata) == 0:
return None return None
@ -143,10 +154,7 @@ def save_files(js_data, images, do_make_zip, index):
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
for image_index, filedata in enumerate(images, start_index): for image_index, filedata in enumerate(images, start_index):
if filedata.startswith("data:image/png;base64,"): image = image_from_url_text(filedata)
filedata = filedata[len("data:image/png;base64,"):]
image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8'))))
is_grid = image_index < p.index_of_first_image is_grid = image_index < p.index_of_first_image
i = 0 if is_grid else (image_index - p.index_of_first_image) i = 0 if is_grid else (image_index - p.index_of_first_image)
@ -176,6 +184,23 @@ def save_files(js_data, images, do_make_zip, index):
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}") return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
def save_pil_to_file(pil_image, dir=None):
use_metadata = False
metadata = PngImagePlugin.PngInfo()
for key, value in pil_image.info.items():
if isinstance(key, str) and isinstance(value, str):
metadata.add_text(key, value)
use_metadata = True
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
return file_obj
# override save to file function so that it also writes PNG info
gr.processing_utils.save_pil_to_file = save_pil_to_file
def wrap_gradio_call(func, extra_outputs=None): def wrap_gradio_call(func, extra_outputs=None):
def f(*args, extra_outputs_array=extra_outputs, **kwargs): def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
@ -304,7 +329,7 @@ def visit(x, func, path=""):
def add_style(name: str, prompt: str, negative_prompt: str): def add_style(name: str, prompt: str, negative_prompt: str):
if name is None: if name is None:
return [gr_show(), gr_show()] return [gr_show() for x in range(4)]
style = modules.styles.PromptStyle(name, prompt, negative_prompt) style = modules.styles.PromptStyle(name, prompt, negative_prompt)
shared.prompt_styles.styles[style.name] = style shared.prompt_styles.styles[style.name] = style
@ -429,29 +454,38 @@ def create_toprow(is_img2img):
id_part = "img2img" if is_img2img else "txt2img" id_part = "img2img" if is_img2img else "txt2img"
with gr.Row(elem_id="toprow"): with gr.Row(elem_id="toprow"):
with gr.Column(scale=4): with gr.Column(scale=6):
with gr.Row(): with gr.Row():
with gr.Column(scale=80): with gr.Column(scale=80):
with gr.Row(): with gr.Row():
prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2) prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
with gr.Column(scale=1, elem_id="roll_col"): placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) )
paste = gr.Button(value=paste_symbol, elem_id="paste")
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
with gr.Column(scale=10, elem_id="style_pos_col"):
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
with gr.Row(): with gr.Row():
with gr.Column(scale=8): with gr.Column(scale=80):
with gr.Row(): with gr.Row():
negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2) negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
with gr.Column(scale=1, elem_id="roll_col"): placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
sh = gr.Button(elem_id="sh", visible=True) )
with gr.Column(scale=1, elem_id="style_neg_col"): with gr.Column(scale=1, elem_id="roll_col"):
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
paste = gr.Button(value=paste_symbol, elem_id="paste")
save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
button_interrogate = None
button_deepbooru = None
if is_img2img:
with gr.Column(scale=1, elem_id="interrogate_col"):
button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
if cmd_opts.deepdanbooru:
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
with gr.Column(scale=1): with gr.Column(scale=1):
with gr.Row(): with gr.Row():
@ -471,20 +505,14 @@ def create_toprow(is_img2img):
outputs=[], outputs=[],
) )
with gr.Row(scale=1): with gr.Row():
if is_img2img: with gr.Column(scale=1, elem_id="style_pos_col"):
interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
if cmd_opts.deepdanbooru:
deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
else:
deepbooru = None
else:
interrogate = None
deepbooru = None
prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
save_style = gr.Button('Create style', elem_id="style_create")
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button with gr.Column(scale=1, elem_id="style_neg_col"):
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
def setup_progressbar(progressbar, preview, id_part, textinfo=None): def setup_progressbar(progressbar, preview, id_part, textinfo=None):
@ -588,7 +616,7 @@ def create_ui(wrap_gradio_gpu_call):
txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
txt2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='txt2img_gallery').style(grid=4) txt2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='txt2img_gallery').style(grid=4)
with gr.Group(): with gr.Column():
with gr.Row(): with gr.Row():
save = gr.Button('Save') save = gr.Button('Save')
send_to_img2img = gr.Button('Send to img2img') send_to_img2img = gr.Button('Send to img2img')
@ -744,10 +772,10 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
with gr.TabItem('img2img', id='img2img'): with gr.TabItem('img2img', id='img2img'):
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool) init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480)
with gr.TabItem('Inpaint', id='inpaint'): with gr.TabItem('Inpaint', id='inpaint'):
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA") init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
@ -803,7 +831,7 @@ def create_ui(wrap_gradio_gpu_call):
img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
img2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='img2img_gallery').style(grid=4) img2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='img2img_gallery').style(grid=4)
with gr.Group(): with gr.Column():
with gr.Row(): with gr.Row():
save = gr.Button('Save') save = gr.Button('Save')
img2img_send_to_img2img = gr.Button('Send to img2img') img2img_send_to_img2img = gr.Button('Send to img2img')
@ -1166,6 +1194,7 @@ def create_ui(wrap_gradio_gpu_call):
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()]) train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
batch_size = gr.Number(label='Batch size', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
@ -1244,6 +1273,7 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[ inputs=[
train_embedding_name, train_embedding_name,
learn_rate, learn_rate,
batch_size,
dataset_directory, dataset_directory,
log_directory, log_directory,
training_width, training_width,
@ -1268,6 +1298,7 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[ inputs=[
train_hypernetwork_name, train_hypernetwork_name,
learn_rate, learn_rate,
batch_size,
dataset_directory, dataset_directory,
log_directory, log_directory,
steps, steps,
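
Editor's note: the new batch_size field follows standard Gradio wiring — the component is created once, then appended to each inputs list, so the training callback receives it as one more positional argument in list order. A minimal sketch of that pattern under the same Gradio 3.x assumption (the train function and names here are hypothetical, not the webui's):

import gradio as gr

def train(embedding_name, learn_rate, batch_size):
    # Each entry of inputs=[...] arrives as one positional argument, in order.
    return f"training {embedding_name}: lr={learn_rate}, bs={int(batch_size)}"

with gr.Blocks() as demo:
    name = gr.Dropdown(label="Embedding", choices=["a", "b"])
    lr = gr.Textbox(label="Learning rate", value="0.005")
    bs = gr.Number(label="Batch size", value=1, precision=0)  # the new input
    out = gr.Textbox(label="Status")
    gr.Button("Train").click(fn=train, inputs=[name, lr, bs], outputs=[out])

demo.launch()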


@ -4,7 +4,7 @@ fairscale==0.4.4
fonts fonts
font-roboto font-roboto
gfpgan gfpgan
gradio==3.4.1 gradio==3.5
invisible-watermark invisible-watermark
numpy numpy
omegaconf omegaconf


@ -2,7 +2,7 @@ transformers==4.19.2
diffusers==0.3.0 diffusers==0.3.0
basicsr==1.4.2 basicsr==1.4.2
gfpgan==1.3.8 gfpgan==1.3.8
gradio==3.4.1 gradio==3.5
numpy==1.23.3 numpy==1.23.3
Pillow==9.2.0 Pillow==9.2.0
realesrgan==0.3.0 realesrgan==0.3.0


@ -50,9 +50,9 @@ document.addEventListener("DOMContentLoaded", function() {
document.addEventListener('keydown', function(e) { document.addEventListener('keydown', function(e) {
var handled = false; var handled = false;
if (e.key !== undefined) { if (e.key !== undefined) {
if((e.key == "Enter" && (e.metaKey || e.ctrlKey))) handled = true; if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
} else if (e.keyCode !== undefined) { } else if (e.keyCode !== undefined) {
if((e.keyCode == 13 && (e.metaKey || e.ctrlKey))) handled = true; if((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
} }
if (handled) { if (handled) {
button = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');


@ -1,7 +1,9 @@
import copy
import math import math
import os import os
import sys import sys
import traceback import traceback
import shlex
import modules.scripts as scripts import modules.scripts as scripts
import gradio as gr import gradio as gr
@ -10,6 +12,75 @@ from modules.processing import Processed, process_images
from PIL import Image from PIL import Image
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
def process_string_tag(tag):
return tag
def process_int_tag(tag):
return int(tag)
def process_float_tag(tag):
return float(tag)
def process_boolean_tag(tag):
return tag == "true"
prompt_tags = {
"sd_model": None,
"outpath_samples": process_string_tag,
"outpath_grids": process_string_tag,
"prompt_for_display": process_string_tag,
"prompt": process_string_tag,
"negative_prompt": process_string_tag,
"styles": process_string_tag,
"seed": process_int_tag,
"subseed_strength": process_float_tag,
"subseed": process_int_tag,
"seed_resize_from_h": process_int_tag,
"seed_resize_from_w": process_int_tag,
"sampler_index": process_int_tag,
"batch_size": process_int_tag,
"n_iter": process_int_tag,
"steps": process_int_tag,
"cfg_scale": process_float_tag,
"width": process_int_tag,
"height": process_int_tag,
"restore_faces": process_boolean_tag,
"tiling": process_boolean_tag,
"do_not_save_samples": process_boolean_tag,
"do_not_save_grid": process_boolean_tag
}
def cmdargs(line):
args = shlex.split(line)
pos = 0
res = {}
while pos < len(args):
arg = args[pos]
assert arg.startswith("--"), f'must start with "--": {arg}'
tag = arg[2:]
func = prompt_tags.get(tag, None)
assert func, f'unknown commandline option: {arg}'
assert pos+1 < len(args), f'missing argument for command line option {arg}'
val = args[pos+1]
res[tag] = func(val)
pos += 2
return res
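
Editor's note: as a concrete illustration, a prompts-file line such as the one below is split by shlex (so quoting is honoured) and each tag's handler applies its type. The values are made up:

line = '--prompt "a painting of a fox" --steps 20 --cfg_scale 7.5 --seed 42'
overrides = cmdargs(line)
# -> {'prompt': 'a painting of a fox', 'steps': 20, 'cfg_scale': 7.5, 'seed': 42}
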
class Script(scripts.Script): class Script(scripts.Script):
def title(self): def title(self):
return "Prompts from file or textbox" return "Prompts from file or textbox"
@ -32,26 +103,48 @@ class Script(scripts.Script):
return [ gr.Checkbox.update(visible = True), gr.File.update(visible = not checkbox_txt), gr.TextArea.update(visible = checkbox_txt) ] return [ gr.Checkbox.update(visible = True), gr.File.update(visible = not checkbox_txt), gr.TextArea.update(visible = checkbox_txt) ]
def run(self, p, checkbox_txt, data: bytes, prompt_txt: str): def run(self, p, checkbox_txt, data: bytes, prompt_txt: str):
if (checkbox_txt): if checkbox_txt:
lines = [x.strip() for x in prompt_txt.splitlines()] lines = [x.strip() for x in prompt_txt.splitlines()]
else: else:
lines = [x.strip() for x in data.decode('utf8', errors='ignore').split("\n")] lines = [x.strip() for x in data.decode('utf8', errors='ignore').split("\n")]
lines = [x for x in lines if len(x) > 0] lines = [x for x in lines if len(x) > 0]
img_count = len(lines) * p.n_iter
batch_count = math.ceil(img_count / p.batch_size)
loop_count = math.ceil(batch_count / p.n_iter)
print(f"Will process {img_count} images in {batch_count} batches.")
p.do_not_save_grid = True p.do_not_save_grid = True
state.job_count = batch_count job_count = 0
jobs = []
for line in lines:
if "--" in line:
try:
args = cmdargs(line)
except Exception:
print(f"Error parsing line [line] as commandline:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
args = {"prompt": line}
else:
args = {"prompt": line}
n_iter = args.get("n_iter", 1)
if n_iter != 1:
job_count += n_iter
else:
job_count += 1
jobs.append(args)
print(f"Will process {len(lines)} lines in {job_count} jobs.")
state.job_count = job_count
images = [] images = []
for loop_no in range(loop_count): for n, args in enumerate(jobs):
state.job = f"{loop_no + 1} out of {loop_count}" state.job = f"{state.job_no + 1} out of {state.job_count}"
p.prompt = lines[loop_no*p.batch_size:(loop_no+1)*p.batch_size] * p.n_iter
proc = process_images(p) copy_p = copy.copy(p)
for k, v in args.items():
setattr(copy_p, k, v)
proc = process_images(copy_p)
images += proc.images images += proc.images
return Processed(p, images, p.seed, "") return Processed(p, images, p.seed, "")
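
Editor's note: the loop above relies on a shallow-copy-then-override pattern — each job gets its own copy of the processing object, so per-line flags never leak into later jobs or mutate the UI-configured template. A self-contained sketch; the Processing stand-in is hypothetical:

import copy

class Processing:
    """Stand-in for the webui's processing object; attributes are illustrative."""
    def __init__(self):
        self.prompt = ""
        self.steps = 50
        self.seed = -1

p = Processing()  # the template configured in the UI
for args in ({"prompt": "a fox", "steps": 20}, {"prompt": "a cat", "seed": 7}):
    copy_p = copy.copy(p)       # shallow copy: fine for scalar overrides like these
    for k, v in args.items():
        setattr(copy_p, k, v)   # per-line values win over the template
    print(copy_p.prompt, copy_p.steps, copy_p.seed)
print(p.prompt, p.steps, p.seed)  # template is unchanged: '' 50 -1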


@ -12,7 +12,7 @@ import gradio as gr
from modules import images from modules import images
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from modules.processing import process_images, Processed, get_correct_sampler from modules.processing import process_images, Processed, get_correct_sampler, StableDiffusionProcessingTxt2Img
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
import modules.sd_samplers import modules.sd_samplers
@ -354,6 +354,9 @@ class Script(scripts.Script):
else: else:
total_steps = p.steps * len(xs) * len(ys) total_steps = p.steps * len(xs) * len(ys)
if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
total_steps *= 2
print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})") print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})")
shared.total_tqdm.updateTotal(total_steps * p.n_iter) shared.total_tqdm.updateTotal(total_steps * p.n_iter)
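
Editor's note: the added *= 2 accounts for the txt2img highres fix running a second sampling pass over every image. A worked example with assumed numbers:

steps, xs, ys, n_iter = 30, 3, 4, 1
total_steps = steps * xs * ys      # 360 first-pass steps for a 3x4 grid
enable_hr = True
if enable_hr:
    total_steps *= 2               # second pass per image under the highres fix
print(total_steps * n_iter)        # 720 steps reported to the progress bar
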


@ -115,7 +115,7 @@
padding: 0.4em 0; padding: 0.4em 0;
} }
#roll, #paste{ #roll, #paste, #style_create, #style_apply{
min-width: 2em; min-width: 2em;
min-height: 2em; min-height: 2em;
max-width: 2em; max-width: 2em;
@ -126,14 +126,14 @@
margin: 0.1em 0; margin: 0.1em 0;
} }
#style_apply, #style_create, #interrogate{ #interrogate_col{
margin: 0.75em 0.25em 0.25em 0.25em; min-width: 0 !important;
min-width: 5em; max-width: 8em !important;
} }

#style_apply, #style_create, #deepbooru{ #interrogate, #deepbooru{
margin: 0.75em 0.25em 0.25em 0.25em; margin: 0em 0.25em 0.9em 0.25em;
min-width: 5em; min-width: 8em;
max-width: 8em;
} }
#style_pos_col, #style_neg_col{ #style_pos_col, #style_neg_col{
@ -167,10 +167,6 @@ button{
align-self: stretch !important; align-self: stretch !important;
} }
#img2maskimg .h-60{
height: 30rem;
}
.overflow-hidden, .gr-panel{ .overflow-hidden, .gr-panel{
overflow: visible !important; overflow: visible !important;
} }
@ -241,13 +237,6 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s
margin: 0; margin: 0;
} }
.gr-panel div.flex-col div.justify-between div{
position: absolute;
top: -0.1em;
right: 1em;
padding: 0 0.5em;
}
#settings .gr-panel div.flex-col div.justify-between div{ #settings .gr-panel div.flex-col div.justify-between div{
position: relative; position: relative;
z-index: 200; z-index: 200;
@ -320,6 +309,8 @@ input[type="range"]{
height: 100%; height: 100%;
overflow: auto; overflow: auto;
background-color: rgba(20, 20, 20, 0.95); background-color: rgba(20, 20, 20, 0.95);
user-select: none;
-webkit-user-select: none;
} }
.modalControls { .modalControls {
@ -443,10 +434,6 @@ input[type="range"]{
--tw-bg-opacity: 0 !important; --tw-bg-opacity: 0 !important;
} }
#img2img_image div.h-60{
height: 480px;
}
#context-menu{ #context-menu{
z-index:9999; z-index:9999;
position:absolute; position:absolute;
@ -521,3 +508,11 @@ canvas[key="mask"] {
.row.gr-compact{ .row.gr-compact{
overflow: visible; overflow: visible;
} }
#img2img_image, #img2img_image > .h-60, #img2img_image > .h-60 > div, #img2img_image > .h-60 > div > img,
#img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h-60 > div > img
{
height: 480px !important;
max-height: 480px !important;
min-height: 480px !important;
}


@ -82,8 +82,8 @@ then
clone_dir="${PWD##*/}" clone_dir="${PWD##*/}"
fi fi
# Check prequisites # Check prerequisites
for preq in git python3 for preq in "${GIT}" "${python_cmd}"
do do
if ! hash "${preq}" &>/dev/null if ! hash "${preq}" &>/dev/null
then then