commit c87c3b9c11 ("test")
28 changed files with 565 additions and 171 deletions

CODEOWNERS (new file, 1 line added)
@@ -0,0 +1 @@
+* @AUTOMATIC1111

README.md (11 lines changed)
@@ -28,10 +28,12 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
 - CodeFormer, face restoration tool as an alternative to GFPGAN
 - RealESRGAN, neural network upscaler
 - ESRGAN, neural network upscaler with a lot of third party models
-- SwinIR, neural network upscaler
+- SwinIR and Swin2SR([see here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2092)), neural network upscalers
 - LDSR, Latent diffusion super resolution upscaling
 - Resizing aspect ratio options
 - Sampling method selection
+- Adjust sampler eta values (noise multiplier)
+- More advanced noise setting options
 - Interrupt processing at any time
 - 4GB video card support (also reports of 2GB working)
 - Correct seeds for batches

@@ -67,6 +69,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
 - also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2`
 - No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
 - DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args)
+- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args)

 ## Installation and Running
 Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.

@@ -116,13 +119,17 @@ The documentation was moved from this README over to the project's [wiki](https:
 - CodeFormer - https://github.com/sczhou/CodeFormer
 - ESRGAN - https://github.com/xinntao/ESRGAN
 - SwinIR - https://github.com/JingyunLiang/SwinIR
+- Swin2SR - https://github.com/mv-lab/swin2sr
 - LDSR - https://github.com/Hafiidz/latent-diffusion
 - Ideas for optimizations - https://github.com/basujindal/stable-diffusion
 - Doggettx - Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
+- InvokeAI, lstein - Cross Attention layer optimization - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
 - Rinon Gal - Textual Inversion - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
 - Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
 - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
 - CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
+- Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
+- xformers - https://github.com/facebookresearch/xformers
+- DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
 - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
-- DeepDanbooru - interrogator for anime diffusors https://github.com/KichangKim/DeepDanbooru
 - (You)

@@ -1045,7 +1045,6 @@ Bakemono Zukushi,0.67051035,anime
 Lucy Madox Brown,0.67032814,fineart
 Paul Wonner,0.6700563,scribbles
 Guido Borelli Da Caluso,0.66966087,digipa-high-impact
-Guido Borelli da Caluso,0.66966087,digipa-high-impact
 Emil Alzamora,0.5844039,nudity
 Heinrich Brocksieper,0.64469147,fineart
 Dan Smith,0.669563,digipa-high-impact

@@ -3,9 +3,9 @@ channels:
   - pytorch
   - defaults
 dependencies:
-  - python=3.8.5
-  - pip=20.3
+  - python=3.10
+  - pip=22.2.2
   - cudatoolkit=11.3
-  - pytorch=1.11.0
-  - torchvision=0.12.0
-  - numpy=1.19.2
+  - pytorch=1.12.1
+  - torchvision=0.13.1
+  - numpy=1.23.1

@@ -25,6 +25,7 @@ addEventListener('keydown', (event) => {
 } else {
 end = target.value.slice(selectionEnd + 1).indexOf(")") + 1;
 weight = parseFloat(target.value.slice(selectionEnd + 1, selectionEnd + 1 + end));
+if (isNaN(weight)) return;
 if (event.key == minus) weight -= 0.1;
 if (event.key == plus) weight += 0.1;

@@ -38,4 +39,7 @@ addEventListener('keydown', (event) => {
 target.selectionStart = selectionStart;
 target.selectionEnd = selectionEnd;
 }
+// Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure its
+// internal Svelte data binding remains in sync.
+target.dispatchEvent(new Event("input", { bubbles: true }));
 });

@@ -80,7 +80,7 @@ titles = {
 "Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.",

 "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
-"Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be bevaing in an unethical manner.",
+"Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",
 }

@@ -101,7 +101,8 @@ function create_tab_index_args(tabId, args){
 }

 function get_extras_tab_index(){
-return create_tab_index_args('mode_extras', arguments)
+const [,,...args] = [...arguments]
+return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args]
 }

 function create_submit_args(args){

@@ -1,21 +1,75 @@
 import os.path
 from concurrent.futures import ProcessPoolExecutor
-from multiprocessing import get_context
+import multiprocessing
+import time
+
+
+def get_deepbooru_tags(pil_image):
+    """
+    This method is for running only one image at a time for simple use. Used to the img2img interrogate.
+    """
+    from modules import shared  # prevents circular reference
+    create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, shared.opts.deepbooru_sort_alpha)
+    shared.deepbooru_process_return["value"] = -1
+    shared.deepbooru_process_queue.put(pil_image)
+    while shared.deepbooru_process_return["value"] == -1:
+        time.sleep(0.2)
+    tags = shared.deepbooru_process_return["value"]
+    release_process()
+    return tags
+
+
-def _load_tf_and_return_tags(pil_image, threshold):
+def deepbooru_process(queue, deepbooru_process_return, threshold, alpha_sort):
+    model, tags = get_deepbooru_tags_model()
+    while True:  # while process is running, keep monitoring queue for new image
+        pil_image = queue.get()
+        if pil_image == "QUIT":
+            break
+        else:
+            deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort)
+
+
+def create_deepbooru_process(threshold, alpha_sort):
+    """
+    Creates deepbooru process. A queue is created to send images into the process. This enables multiple images
+    to be processed in a row without reloading the model or creating a new process. To return the data, a shared
+    dictionary is created to hold the tags created. To wait for tags to be returned, a value of -1 is assigned
+    to the dictionary and the method adding the image to the queue should wait for this value to be updated with
+    the tags.
+    """
+    from modules import shared  # prevents circular reference
+    shared.deepbooru_process_manager = multiprocessing.Manager()
+    shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
+    shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
+    shared.deepbooru_process_return["value"] = -1
+    shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, alpha_sort))
+    shared.deepbooru_process.start()
+
+
+def release_process():
+    """
+    Stops the deepbooru process to return used memory
+    """
+    from modules import shared  # prevents circular reference
+    shared.deepbooru_process_queue.put("QUIT")
+    shared.deepbooru_process.join()
+    shared.deepbooru_process_queue = None
+    shared.deepbooru_process = None
+    shared.deepbooru_process_return = None
+    shared.deepbooru_process_manager = None
+
+
+def get_deepbooru_tags_model():
     import deepdanbooru as dd
     import tensorflow as tf
     import numpy as np

     this_folder = os.path.dirname(__file__)
     model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
     if not os.path.exists(os.path.join(model_path, 'project.json')):
         # there is no point importing these every time
         import zipfile
         from basicsr.utils.download_util import load_file_from_url
-        load_file_from_url(r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
-                           model_path)
+        load_file_from_url(
+            r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
+            model_path)
         with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
             zip_ref.extractall(model_path)
         os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))

@@ -24,7 +78,13 @@ def _load_tf_and_return_tags(pil_image, threshold):
     model = dd.project.load_model_from_project(
         model_path, compile_model=True
     )
+    return model, tags
+
+
+def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort):
+    import deepdanbooru as dd
+    import tensorflow as tf
+    import numpy as np
     width = model.input_shape[2]
     height = model.input_shape[1]
     image = np.array(pil_image)

@@ -46,28 +106,27 @@

     for i, tag in enumerate(tags):
         result_dict[tag] = y[i]
-    result_tags_out = []
+    unsorted_tags_in_theshold = []
     result_tags_print = []
     for tag in tags:
         if result_dict[tag] >= threshold:
             if tag.startswith("rating:"):
                 continue
-            result_tags_out.append(tag)
+            unsorted_tags_in_theshold.append((result_dict[tag], tag))
             result_tags_print.append(f'{result_dict[tag]} {tag}')
+
+    # sort tags
+    result_tags_out = []
+    sort_ndx = 0
+    if alpha_sort:
+        sort_ndx = 1
+
+    # sort by reverse by likelihood and normal for alpha
+    unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
+    for weight, tag in unsorted_tags_in_theshold:
+        result_tags_out.append(tag)
+
     print('\n'.join(sorted(result_tags_print, reverse=True)))

     return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ')
-
-
-def subprocess_init_no_cuda():
-    import os
-    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
-
-
-def get_deepbooru_tags(pil_image, threshold=0.5):
-    context = get_context('spawn')
-    with ProcessPoolExecutor(initializer=subprocess_init_no_cuda, mp_context=context) as executor:
-        f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, )
-        ret = f.result()  # will rethrow any exceptions
-    return ret

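For readers skimming the diff: the new deepbooru code replaces the per-call ProcessPoolExecutor with one long-lived worker process fed through a managed queue, using a shared dict as the return channel and -1 as the "no result yet" marker. A minimal standalone sketch of that pattern (illustrative names, not the project's API):

    import multiprocessing
    import time

    def worker(queue, result):
        while True:
            item = queue.get()
            if item == "QUIT":                    # sentinel shuts the worker down
                break
            result["value"] = f"tags for {item}"  # stand-in for the real model call

    if __name__ == "__main__":
        manager = multiprocessing.Manager()
        queue, result = manager.Queue(), manager.dict()
        result["value"] = -1
        proc = multiprocessing.Process(target=worker, args=(queue, result))
        proc.start()

        queue.put("image.png")
        while result["value"] == -1:              # poll until the worker answers
            time.sleep(0.2)
        print(result["value"])

        queue.put("QUIT")
        proc.join()

The same create/queue/poll/release sequence is what get_deepbooru_tags above and the preprocessing code later in this commit perform against shared.deepbooru_process_queue and shared.deepbooru_process_return.
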
@@ -1,3 +1,4 @@
+import math
 import os

 import numpy as np

@@ -19,7 +20,7 @@ import gradio as gr
 cached_images = {}


-def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
+def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
     devices.torch_gc()

     imageArr = []

@@ -67,8 +68,13 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
             info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
             image = res

+        if resize_mode == 1:
+            upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
+            crop_info = " (crop)" if upscaling_crop else ""
+            info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
+
         if upscaling_resize != 1.0:
-            def upscale(image, scaler_index, resize):
+            def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
                 small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
                 pixels = tuple(np.array(small).flatten().tolist())
                 key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels

@@ -77,15 +83,19 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
                 if c is None:
                     upscaler = shared.sd_upscalers[scaler_index]
                     c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
+                    if mode == 1 and crop:
+                        cropped = Image.new("RGB", (resize_w, resize_h))
+                        cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
+                        c = cropped
                     cached_images[key] = c

                 return c

             info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
-            res = upscale(image, extras_upscaler_1, upscaling_resize)
+            res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)

             if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
-                res2 = upscale(image, extras_upscaler_2, upscaling_resize)
+                res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
                 info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
                 res = Image.blend(res, res2, extras_upscaler_2_visibility)

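As a worked example of the new resize-to-dimensions mode in run_extras (resize_mode == 1): the scale factor is the larger of the two per-axis ratios, so the upscaled image always covers the requested size, and the optional crop then trims it to the exact target. Numbers below are illustrative:

    image_w, image_h = 512, 768          # source image
    target_w, target_h = 1024, 1024      # requested output size
    scale = max(target_w / image_w, target_h / image_h)          # max(2.0, 1.33...) -> 2.0
    upscaled = (round(image_w * scale), round(image_h * scale))  # (1024, 1536)
    print(scale, upscaled)
    # with upscaling_crop enabled, the 1024x1536 result is pasted centered into
    # a 1024x1024 canvas, i.e. center-cropped to the exact requested size
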
@@ -190,7 +200,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
             theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount))  # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint
             if save_as_half:
                 theta_0[key] = theta_0[key].half()

     for key in theta_1.keys():
         if 'model' in key and key not in theta_0:
             theta_0[key] = theta_1[key]

@@ -14,6 +14,7 @@ import torch
 from torch import einsum
 from einops import rearrange, repeat
 import modules.textual_inversion.dataset
+from modules.textual_inversion.learn_schedule import LearnSchedule


 class HypernetworkModule(torch.nn.Module):

@@ -42,7 +43,7 @@ class Hypernetwork:
     filename = None
     name = None

-    def __init__(self, name=None):
+    def __init__(self, name=None, enable_sizes=None):
         self.filename = None
         self.name = name
         self.layers = {}

@@ -50,7 +51,7 @@
         self.sd_checkpoint = None
         self.sd_checkpoint_name = None

-        for size in [320, 640, 768, 1280]:
+        for size in enable_sizes or []:
             self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))

     def weights(self):

@@ -119,6 +120,17 @@ def load_hypernetwork(filename):
         shared.loaded_hypernetwork = None


+def find_closest_hypernetwork_name(search: str):
+    if not search:
+        return None
+    search = search.lower()
+    applicable = [name for name in shared.hypernetworks if search in name.lower()]
+    if not applicable:
+        return None
+    applicable = sorted(applicable, key=lambda name: len(name))
+    return applicable[0]
+
+
 def apply_hypernetwork(hypernetwork, context, layer=None):
     hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)

@@ -163,7 +175,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):


 def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
-    assert hypernetwork_name, 'embedding not selected'
+    assert hypernetwork_name, 'hypernetwork not selected'

     path = shared.hypernetworks.get(hypernetwork_name, None)
     shared.loaded_hypernetwork = Hypernetwork()

@@ -175,6 +187,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')

     log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
+    unload = shared.opts.unload_models_when_training

     if save_hypernetwork_every > 0:
         hypernetwork_dir = os.path.join(log_directory, "hypernetworks")

@@ -188,19 +201,19 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
     else:
         images_dir = None

-    cond_model = shared.sd_model.cond_stage_model
-
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
-        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
+
+    if unload:
+        shared.sd_model.cond_stage_model.to(devices.cpu)
+        shared.sd_model.first_stage_model.to(devices.cpu)

     hypernetwork = shared.loaded_hypernetwork
     weights = hypernetwork.weights()
     for weight in weights:
         weight.requires_grad = True
-
-    optimizer = torch.optim.AdamW(weights, lr=learn_rate)

     losses = torch.zeros((32,))

     last_saved_file = "<none>"

@@ -210,22 +223,34 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
     if ititial_step > steps:
         return hypernetwork, filename

+    schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
+    (learn_rate, end_step) = next(schedules)
+    print(f'Training at rate of {learn_rate} until step {end_step}')
+
+    optimizer = torch.optim.AdamW(weights, lr=learn_rate)
+
     pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
-    for i, (x, text) in pbar:
+    for i, (x, text, cond) in pbar:
         hypernetwork.step = i + ititial_step

-        if hypernetwork.step > steps:
-            break
+        if hypernetwork.step > end_step:
+            try:
+                (learn_rate, end_step) = next(schedules)
+            except Exception:
+                break
+            tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
+            for pg in optimizer.param_groups:
+                pg['lr'] = learn_rate

         if shared.state.interrupted:
             break

         with torch.autocast("cuda"):
-            c = cond_model([text])
+            cond = cond.to(devices.device)
             x = x.to(devices.device)
-            loss = shared.sd_model(x.unsqueeze(0), c)[0]
+            loss = shared.sd_model(x.unsqueeze(0), cond)[0]
             del x
+            del cond

         losses[hypernetwork.step % losses.shape[0]] = loss.item()

@@ -244,6 +269,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
             preview_text = text if preview_image_prompt == "" else preview_image_prompt

+            optimizer.zero_grad()
+            shared.sd_model.cond_stage_model.to(devices.device)
+            shared.sd_model.first_stage_model.to(devices.device)
+
             p = processing.StableDiffusionProcessingTxt2Img(
                 sd_model=shared.sd_model,
                 prompt=preview_text,

@@ -255,6 +284,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
             processed = processing.process_images(p)
             image = processed.images[0]

+            if unload:
+                shared.sd_model.cond_stage_model.to(devices.cpu)
+                shared.sd_model.first_stage_model.to(devices.cpu)
+
             shared.state.current_image = image
             image.save(last_saved_image)

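The training loop above now changes the learning rate in place whenever a schedule segment ends, instead of rebuilding the optimizer. A minimal generic PyTorch sketch of that pattern (values are illustrative):

    import torch

    weights = [torch.nn.Parameter(torch.zeros(4))]
    segments = iter([(5e-3, 100), (5e-4, 1000)])   # (rate, last step of the segment)
    learn_rate, end_step = next(segments)
    optimizer = torch.optim.AdamW(weights, lr=learn_rate)

    for step in range(2000):
        if step > end_step:
            try:
                learn_rate, end_step = next(segments)
            except StopIteration:
                break
            for pg in optimizer.param_groups:      # apply the new rate in place
                pg['lr'] = learn_rate
        # forward/backward/optimizer.step() would go here
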
@@ -5,15 +5,15 @@ import gradio as gr

 import modules.textual_inversion.textual_inversion
 import modules.textual_inversion.preprocess
-from modules import sd_hijack, shared
+from modules import sd_hijack, shared, devices
 from modules.hypernetworks import hypernetwork


-def create_hypernetwork(name):
+def create_hypernetwork(name, enable_sizes):
     fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
     assert not os.path.exists(fn), f"file {fn} already exists"

-    hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name)
+    hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes])
     hypernet.save(fn)

     shared.reload_hypernetworks()

@@ -25,6 +25,8 @@ def train_hypernetwork(*args):

     initial_hypernetwork = shared.loaded_hypernetwork

+    assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
+
     try:
         sd_hijack.undo_optimizations()

@@ -39,5 +41,7 @@ Hypernetwork saved to {html.escape(filename)}
         raise
     finally:
         shared.loaded_hypernetwork = initial_hypernetwork
+        shared.sd_model.cond_stage_model.to(devices.device)
+        shared.sd_model.first_stage_model.to(devices.device)
         sd_hijack.apply_optimizations()

|
@ -136,7 +136,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
|
||||||
|
|
||||||
def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
|
def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
|
||||||
with gr.Blocks(analytics_enabled=False) as images_history:
|
with gr.Blocks(analytics_enabled=False) as images_history:
|
||||||
with gr.Tabs() as tabs:
|
with gr.Tabs() as tabs:
|
||||||
with gr.Tab("txt2img history"):
|
with gr.Tab("txt2img history"):
|
||||||
with gr.Blocks(analytics_enabled=False) as images_history_txt2img:
|
with gr.Blocks(analytics_enabled=False) as images_history_txt2img:
|
||||||
show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict)
|
show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict)
|
||||||
|
|
|
@@ -10,6 +10,7 @@ import torch
 import numpy
 import _codecs
 import zipfile
+import re


 # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage

@@ -54,11 +55,27 @@ class RestrictedUnpickler(pickle.Unpickler):
         raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")


+allowed_zip_names = ["archive/data.pkl", "archive/version"]
+allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
+
+
+def check_zip_filenames(filename, names):
+    for name in names:
+        if name in allowed_zip_names:
+            continue
+        if allowed_zip_names_re.match(name):
+            continue
+
+        raise Exception(f"bad file inside {filename}: {name}")
+
+
 def check_pt(filename):
     try:

         # new pytorch format is a zip file
         with zipfile.ZipFile(filename) as z:
+            check_zip_filenames(filename, z.namelist())
+
             with z.open('archive/data.pkl') as file:
                 unpickler = RestrictedUnpickler(file)
                 unpickler.load()

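For context: newer PyTorch checkpoints are zip archives, so the whitelist above rejects any archive member that is not the pickle payload, the version marker, or a numbered data blob, before the restricted unpickler even runs. A standalone sketch of the same check (illustrative, not the module's actual entry point):

    import re
    import zipfile

    allowed_names = {"archive/data.pkl", "archive/version"}
    allowed_re = re.compile(r"^archive/data/\d+$")

    def check_checkpoint_members(path):
        with zipfile.ZipFile(path) as z:
            for name in z.namelist():
                if name in allowed_names or allowed_re.match(name):
                    continue
                raise Exception(f"bad file inside {path}: {name}")

    # check_checkpoint_members("model.ckpt")  # raises on unexpected members
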
@@ -10,6 +10,7 @@ from torch.nn.functional import silu
 import modules.textual_inversion.textual_inversion
 from modules import prompt_parser, devices, sd_hijack_optimizations, shared
 from modules.shared import opts, device, cmd_opts
+from modules.sd_hijack_optimizations import invokeAI_mps_available

 import ldm.modules.attention
 import ldm.modules.diffusionmodules.model

@@ -30,8 +31,16 @@ def apply_optimizations():
     elif cmd_opts.opt_split_attention_v1:
         print("Applying v1 cross attention optimization.")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
+        if not invokeAI_mps_available and shared.device.type == 'mps':
+            print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
+            print("Applying v1 cross attention optimization.")
+            ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+        else:
+            print("Applying cross attention optimization (InvokeAI).")
+            ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
     elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
-        print("Applying cross attention optimization.")
+        print("Applying cross attention optimization (Doggettx).")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
         ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward

@@ -312,7 +321,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                     fixes.append(fix[1])
                 self.hijack.fixes.append(fixes)

-            z1 = self.process_tokens([x[:75] for x in remade_batch_tokens], [x[:75] for x in batch_multipliers])
+            tokens = []
+            multipliers = []
+            for j in range(len(remade_batch_tokens)):
+                if len(remade_batch_tokens[j]) > 0:
+                    tokens.append(remade_batch_tokens[j][:75])
+                    multipliers.append(batch_multipliers[j][:75])
+                else:
+                    tokens.append([self.wrapped.tokenizer.eos_token_id] * 75)
+                    multipliers.append([1.0] * 75)
+
+            z1 = self.process_tokens(tokens, multipliers)
             z = z1 if z is None else torch.cat((z, z1), axis=-2)

             remade_batch_tokens = rem_tokens

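The prompt-chunking change above stops slicing empty chunks and instead pads them to a full 75-token block filled with the end-of-text token at weight 1.0, so every chunk handed to the text encoder has a valid shape. A toy illustration of that rule (token ids are made up; 0 stands in for the tokenizer's end-of-text id):

    EOS = 0
    CHUNK = 75

    def pad_chunk(token_ids, weights):
        if len(token_ids) > 0:
            return token_ids[:CHUNK], weights[:CHUNK]
        # an empty chunk becomes a full block of end-of-text tokens at weight 1.0
        return [EOS] * CHUNK, [1.0] * CHUNK

    tokens, mults = pad_chunk([101, 102, 103], [1.0, 1.2, 1.0])
    empty_tokens, empty_mults = pad_chunk([], [])
    print(len(tokens), len(empty_tokens))  # 3 75
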
@@ -1,6 +1,7 @@
 import math
 import sys
 import traceback
+import importlib

 import torch
 from torch import einsum

@@ -116,6 +117,102 @@ def split_cross_attention_forward(self, x, context=None, mask=None):

     return self.to_out(r2)


+def check_for_psutil():
+    try:
+        spec = importlib.util.find_spec('psutil')
+        return spec is not None
+    except ModuleNotFoundError:
+        return False
+
+invokeAI_mps_available = check_for_psutil()
+
+# -- Taken from https://github.com/invoke-ai/InvokeAI --
+if invokeAI_mps_available:
+    import psutil
+    mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+
+def einsum_op_compvis(q, k, v):
+    s = einsum('b i d, b j d -> b i j', q, k)
+    s = s.softmax(dim=-1, dtype=s.dtype)
+    return einsum('b i j, b j d -> b i d', s, v)
+
+def einsum_op_slice_0(q, k, v, slice_size):
+    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+    for i in range(0, q.shape[0], slice_size):
+        end = i + slice_size
+        r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
+    return r
+
+def einsum_op_slice_1(q, k, v, slice_size):
+    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+    for i in range(0, q.shape[1], slice_size):
+        end = i + slice_size
+        r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
+    return r
+
+def einsum_op_mps_v1(q, k, v):
+    if q.shape[1] <= 4096:  # (512x512) max q.shape[1]: 4096
+        return einsum_op_compvis(q, k, v)
+    else:
+        slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+        return einsum_op_slice_1(q, k, v, slice_size)
+
+def einsum_op_mps_v2(q, k, v):
+    if mem_total_gb > 8 and q.shape[1] <= 4096:
+        return einsum_op_compvis(q, k, v)
+    else:
+        return einsum_op_slice_0(q, k, v, 1)
+
+def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
+    size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
+    if size_mb <= max_tensor_mb:
+        return einsum_op_compvis(q, k, v)
+    div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
+    if div <= q.shape[0]:
+        return einsum_op_slice_0(q, k, v, q.shape[0] // div)
+    return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
+
+def einsum_op_cuda(q, k, v):
+    stats = torch.cuda.memory_stats(q.device)
+    mem_active = stats['active_bytes.all.current']
+    mem_reserved = stats['reserved_bytes.all.current']
+    mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
+    mem_free_torch = mem_reserved - mem_active
+    mem_free_total = mem_free_cuda + mem_free_torch
+    # Divide factor of safety as there's copying and fragmentation
+    return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
+
+def einsum_op(q, k, v):
+    if q.device.type == 'cuda':
+        return einsum_op_cuda(q, k, v)
+
+    if q.device.type == 'mps':
+        if mem_total_gb >= 32:
+            return einsum_op_mps_v1(q, k, v)
+        return einsum_op_mps_v2(q, k, v)
+
+    # Smaller slices are faster due to L2/L3/SLC caches.
+    # Tested on i7 with 8MB L3 cache.
+    return einsum_op_tensor_mem(q, k, v, 32)
+
+def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
+    h = self.heads
+
+    q = self.to_q(x)
+    context = default(context, x)
+
+    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+    k = self.to_k(context_k) * self.scale
+    v = self.to_v(context_v)
+    del context, context_k, context_v, x
+
+    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+    r = einsum_op(q, k, v)
+    return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
+
+# -- End of code from https://github.com/invoke-ai/InvokeAI --
+
 def xformers_attention_forward(self, x, context=None, mask=None):
     h = self.heads
     q_in = self.to_q(x)

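To make the memory-driven slicing above concrete: einsum_op_tensor_mem estimates the size of the full attention matrix and, if it exceeds the budget, picks the next power-of-two split that fits. The arithmetic for one hypothetical fp16 call:

    q_batch, q_len, k_len, elem_size = 16, 4096, 4096, 2          # illustrative shapes
    size_mb = q_batch * q_len * k_len * elem_size // (1 << 20)    # 512 MB estimate
    max_tensor_mb = 32
    div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()    # 16
    print(size_mb, div)  # 512 16
    # div (16) <= q_batch (16), so einsum_op_slice_0 runs with slice size
    # q_batch // div = 1, i.e. the batch dimension is processed one row at a time
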
@@ -50,9 +50,10 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with
 parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
 parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
 parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
-parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
-parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
+parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
+parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
 parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
+parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
 parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[])
 parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
 parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)

@@ -85,6 +86,7 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
 xformers_available = False
 config_filename = cmd_opts.ui_settings_file

+os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
 hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
 loaded_hypernetwork = None

@@ -227,6 +229,10 @@ options_templates.update(options_section(('system', "System"), {
     "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
 }))

+options_templates.update(options_section(('training', "Training"), {
+    "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
+}))
+
 options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True),
     "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}),

@@ -249,8 +255,8 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
     "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
     "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
     "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
-    "interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
     "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
+    "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
 }))

 options_templates.update(options_section(('ui', "User interface"), {

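A small note on the renamed command-line switches above: argparse turns each dashed flag into an underscore attribute, which is how sd_hijack.py reads cmd_opts.opt_split_attention_invokeai earlier in this commit. A tiny standalone demo (not the webui's actual parser object):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--opt-split-attention", action='store_true')
    parser.add_argument("--opt-split-attention-invokeai", action='store_true')
    parser.add_argument("--disable-opt-split-attention", action='store_true')

    args = parser.parse_args(["--opt-split-attention-invokeai"])
    print(args.opt_split_attention_invokeai)  # True
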
@@ -8,14 +8,14 @@ from torchvision import transforms

 import random
 import tqdm
-from modules import devices
+from modules import devices, shared
 import re

 re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")


 class PersonalizedBase(Dataset):
-    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
+    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):

         self.placeholder_token = placeholder_token

@@ -32,12 +32,15 @@ class PersonalizedBase(Dataset):

         assert data_root, 'dataset directory not specified'

+        cond_model = shared.sd_model.cond_stage_model
+
         self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
         print("Preparing dataset...")
         for path in tqdm.tqdm(self.image_paths):
-            image = Image.open(path)
-            image = image.convert('RGB')
-            image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
+            try:
+                image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
+            except Exception:
+                continue

             filename = os.path.basename(path)
             filename_tokens = os.path.splitext(filename)[0]

@@ -52,7 +55,13 @@ class PersonalizedBase(Dataset):
             init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
             init_latent = init_latent.to(devices.cpu)

-            self.dataset.append((init_latent, filename_tokens))
+            if include_cond:
+                text = self.create_text(filename_tokens)
+                cond = cond_model([text]).to(devices.cpu)
+            else:
+                cond = None
+
+            self.dataset.append((init_latent, filename_tokens, cond))

         self.length = len(self.dataset) * repeats

@@ -63,6 +72,12 @@ class PersonalizedBase(Dataset):
     def shuffle(self):
         self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]

+    def create_text(self, filename_tokens):
+        text = random.choice(self.lines)
+        text = text.replace("[name]", self.placeholder_token)
+        text = text.replace("[filewords]", ' '.join(filename_tokens))
+        return text
+
     def __len__(self):
         return self.length

@@ -71,10 +86,7 @@ class PersonalizedBase(Dataset):
         self.shuffle()

         index = self.indexes[i % len(self.indexes)]
-        x, filename_tokens = self.dataset[index]
+        x, filename_tokens, cond = self.dataset[index]

-        text = random.choice(self.lines)
-        text = text.replace("[name]", self.placeholder_token)
-        text = text.replace("[filewords]", ' '.join(filename_tokens))
-
-        return x, text
+        text = self.create_text(filename_tokens)
+        return x, text, cond

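The new create_text helper fills the training prompt templates: each template line may contain the [name] and [filewords] placeholders, replaced by the placeholder token and the words parsed from the image filename. A standalone illustration with made-up data:

    import random

    lines = ["a photo of [name], [filewords]", "a rendering of [name]"]
    placeholder_token = "my-style"
    filename_tokens = ["portrait", "studio", "lighting"]

    text = random.choice(lines)
    text = text.replace("[name]", placeholder_token)
    text = text.replace("[filewords]", ' '.join(filename_tokens))
    print(text)  # e.g. "a photo of my-style, portrait studio lighting"
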
modules/textual_inversion/learn_schedule.py (new file, 34 lines added)
|
@ -0,0 +1,34 @@
|
||||||
|
|
||||||
|
class LearnSchedule:
|
||||||
|
def __init__(self, learn_rate, max_steps, cur_step=0):
|
||||||
|
pairs = learn_rate.split(',')
|
||||||
|
self.rates = []
|
||||||
|
self.it = 0
|
||||||
|
self.maxit = 0
|
||||||
|
for i, pair in enumerate(pairs):
|
||||||
|
tmp = pair.split(':')
|
||||||
|
if len(tmp) == 2:
|
||||||
|
step = int(tmp[1])
|
||||||
|
if step > cur_step:
|
||||||
|
self.rates.append((float(tmp[0]), min(step, max_steps)))
|
||||||
|
self.maxit += 1
|
||||||
|
if step > max_steps:
|
||||||
|
return
|
||||||
|
elif step == -1:
|
||||||
|
self.rates.append((float(tmp[0]), max_steps))
|
||||||
|
self.maxit += 1
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
self.rates.append((float(tmp[0]), max_steps))
|
||||||
|
self.maxit += 1
|
||||||
|
return
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if self.it < self.maxit:
|
||||||
|
self.it += 1
|
||||||
|
return self.rates[self.it - 1]
|
||||||
|
else:
|
||||||
|
raise StopIteration
|
|
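LearnSchedule parses a comma-separated schedule of rate:step pairs (a bare rate runs to max_steps) and iterates over (rate, end_step) tuples. A rough usage sketch against the new module, with an invented schedule string:

# Assumes the webui repo is on the import path; the schedule string is illustrative.
from modules.textual_inversion.learn_schedule import LearnSchedule

schedules = iter(LearnSchedule("0.005:100, 0.001:1000, 0.0005", max_steps=2000, cur_step=0))
for learn_rate, end_step in schedules:
    print(f"train at {learn_rate} until step {end_step}")
# expected: 0.005 until step 100, 0.001 until step 1000, 0.0005 until step 2000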
@@ -3,11 +3,14 @@ from PIL import Image, ImageOps
 import platform
 import sys
 import tqdm
+import time
 
 from modules import shared, images
+from modules.shared import opts, cmd_opts
+if cmd_opts.deepdanbooru:
+    import modules.deepbooru as deepbooru
 
 
-def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption):
+def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
     width = process_width
     height = process_height
     src = os.path.abspath(process_src)

@@ -25,10 +28,21 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
     if process_caption:
         shared.interrogator.load()
 
+    if process_caption_deepbooru:
+        deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, opts.deepbooru_sort_alpha)
+
     def save_pic_with_caption(image, index):
         if process_caption:
             caption = "-" + shared.interrogator.generate_caption(image)
             caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
+        elif process_caption_deepbooru:
+            shared.deepbooru_process_return["value"] = -1
+            shared.deepbooru_process_queue.put(image)
+            while shared.deepbooru_process_return["value"] == -1:
+                time.sleep(0.2)
+            caption = "-" + shared.deepbooru_process_return["value"]
+            caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
+            shared.deepbooru_process_return["value"] = -1
         else:
             caption = filename
             caption = os.path.splitext(caption)[0]

@@ -46,7 +60,10 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
     for index, imagefile in enumerate(tqdm.tqdm(files)):
         subindex = [0]
         filename = os.path.join(src, imagefile)
-        img = Image.open(filename).convert("RGB")
+        try:
+            img = Image.open(filename).convert("RGB")
+        except Exception:
+            continue
 
         if shared.state.interrupted:
             break

@@ -80,6 +97,10 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
     if process_caption:
         shared.interrogator.send_blip_to_ram()
 
+    if process_caption_deepbooru:
+        deepbooru.release_process()
+
 
 def sanitize_caption(base_path, original_caption, suffix):
     operating_system = platform.system().lower()
     if (operating_system == "windows"):
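The deepbooru branch above hands each image to a separate tagging process through a queue and polls a shared return slot, treating -1 as "no result yet". A stripped-down sketch of that handshake using only the standard library (the worker body and names here are placeholders, not the webui's deepbooru implementation):

# Hypothetical producer/consumer handshake mirroring the queue-and-poll pattern in the diff.
import time
from multiprocessing import Manager, Process

def tag_worker(queue, ret):
    for item in iter(queue.get, None):       # None tells the worker to shut down
        ret["value"] = f"tags-for-{item}"    # stand-in for real deepbooru tagging

if __name__ == "__main__":
    manager = Manager()
    queue, ret = manager.Queue(), manager.dict({"value": -1})
    worker = Process(target=tag_worker, args=(queue, ret))
    worker.start()

    ret["value"] = -1
    queue.put("example.png")
    while ret["value"] == -1:                # same 0.2s polling idea as save_pic_with_caption
        time.sleep(0.2)
    print(ret["value"])

    queue.put(None)
    worker.join()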
@@ -10,6 +10,7 @@ import datetime
 
 from modules import shared, devices, sd_hijack, processing, sd_models
 import modules.textual_inversion.dataset
+from modules.textual_inversion.learn_schedule import LearnSchedule
 
 
 class Embedding:

@@ -189,8 +190,6 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
     embedding = hijack.embedding_db.word_embeddings[embedding_name]
     embedding.vec.requires_grad = True
 
-    optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
-
     losses = torch.zeros((32,))
 
     last_saved_file = "<none>"

@@ -200,15 +199,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
     if ititial_step > steps:
         return embedding, filename
 
-    tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
-    epoch_len = (tr_img_len * num_repeats) + tr_img_len
+    schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
+    (learn_rate, end_step) = next(schedules)
+    print(f'Training at rate of {learn_rate} until step {end_step}')
+
+    optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
 
     pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
-    for i, (x, text) in pbar:
+    for i, (x, text, _) in pbar:
         embedding.step = i + ititial_step
 
-        if embedding.step > steps:
-            break
+        if embedding.step > end_step:
+            try:
+                (learn_rate, end_step) = next(schedules)
+            except:
+                break
+            tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
+            for pg in optimizer.param_groups:
+                pg['lr'] = learn_rate
 
         if shared.state.interrupted:
             break

@@ -226,10 +234,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
         loss.backward()
         optimizer.step()
 
-        epoch_num = embedding.step // epoch_len
-        epoch_step = embedding.step - (epoch_num * epoch_len) + 1
+        epoch_num = embedding.step // len(ds)
+        epoch_step = embedding.step - (epoch_num * len(ds)) + 1
 
-        pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
+        pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
 
         if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
             last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')

@@ -278,4 +286,3 @@ Last saved image: {html.escape(last_saved_image)}<br/>
     embedding.save(filename)
 
     return embedding, filename
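The training-loop change above swaps the learning rate mid-run: once the current end_step is passed, it pulls the next (rate, end_step) pair from the schedule and writes the new rate into the optimizer's parameter groups. A toy, self-contained sketch of that mechanism (the parameter and schedule values are invented):

# Toy illustration of updating a live optimizer's learning rate from a schedule.
import torch

vec = torch.nn.Parameter(torch.zeros(4))
optimizer = torch.optim.AdamW([vec], lr=0.005)
schedule = iter([(0.005, 100), (0.001, 200)])    # (rate, end_step) pairs, as LearnSchedule yields
learn_rate, end_step = next(schedule)

for step in range(250):
    if step > end_step:
        try:
            learn_rate, end_step = next(schedule)
        except StopIteration:
            break
        for pg in optimizer.param_groups:         # same in-place update as the diff
            pg["lr"] = learn_rate
    # loss.backward() and optimizer.step() would run here in real training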
@@ -22,6 +22,9 @@ def preprocess(*args):
 
 
 def train_embedding(*args):
 
+    assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
+
     try:
         sd_hijack.undo_optimizations()
@@ -108,7 +108,7 @@ def send_gradio_gallery_to_image(x):
 
 
 def save_files(js_data, images, do_make_zip, index):
     import csv
     filenames = []
     fullfns = []

@@ -132,6 +132,8 @@ def save_files(js_data, images, do_make_zip, index):
         images = [images[index]]
         start_index = index
 
+    os.makedirs(opts.outdir_save, exist_ok=True)
+
     with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
         at_start = file.tell() == 0
         writer = csv.writer(file)

@@ -182,8 +184,15 @@ def wrap_gradio_call(func, extra_outputs=None):
         try:
             res = list(func(*args, **kwargs))
         except Exception as e:
+            # When printing out our debug argument list, do not print out more than a MB of text
+            max_debug_str_len = 131072 # (1024*1024)/8
+
             print("Error completing request", file=sys.stderr)
-            print("Arguments:", args, kwargs, file=sys.stderr)
+            argStr = f"Arguments: {str(args)} {str(kwargs)}"
+            print(argStr[:max_debug_str_len], file=sys.stderr)
+            if len(argStr) > max_debug_str_len:
+                print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
+
             print(traceback.format_exc(), file=sys.stderr)
 
         shared.state.job = ""

@@ -318,7 +327,7 @@ def interrogate(image):
 
 
 def interrogate_deepbooru(image):
-    prompt = get_deepbooru_tags(image, opts.interrogate_deepbooru_score_threshold)
+    prompt = get_deepbooru_tags(image)
     return gr_show(True) if prompt is None else prompt
 
 

@@ -558,11 +567,11 @@ def create_ui(wrap_gradio_gpu_call):
                         button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
                         open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id)
 
                         with gr.Row():
                             do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
 
                         with gr.Row():
                             download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
 
                     with gr.Group():
                         html_info = gr.HTML()

@@ -747,11 +756,11 @@ def create_ui(wrap_gradio_gpu_call):
                         button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
                         open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id)
 
                         with gr.Row():
                             do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
 
                         with gr.Row():
                             download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
 
                     with gr.Group():
                         html_info = gr.HTML()

@@ -913,7 +922,15 @@ def create_ui(wrap_gradio_gpu_call):
                 with gr.TabItem('Batch Process'):
                     image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
 
-            upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
+            with gr.Tabs(elem_id="extras_resize_mode"):
+                with gr.TabItem('Scale by'):
+                    upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
+                with gr.TabItem('Scale to'):
+                    with gr.Group():
+                        with gr.Row():
+                            upscaling_resize_w = gr.Number(label="Width", value=512, precision=0)
+                            upscaling_resize_h = gr.Number(label="Height", value=512, precision=0)
+                        upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
 
             with gr.Group():
                 extras_upscaler_1 = gr.Radio(label='Upscaler 1', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")

@@ -944,6 +961,7 @@ def create_ui(wrap_gradio_gpu_call):
             fn=wrap_gradio_gpu_call(modules.extras.run_extras),
             _js="get_extras_tab_index",
             inputs=[
+                dummy_component,
                 dummy_component,
                 extras_image,
                 image_batch,

@@ -951,6 +969,9 @@ def create_ui(wrap_gradio_gpu_call):
                 codeformer_visibility,
                 codeformer_weight,
                 upscaling_resize,
+                upscaling_resize_w,
+                upscaling_resize_h,
+                upscaling_crop,
                 extras_upscaler_1,
                 extras_upscaler_2,
                 extras_upscaler_2_visibility,

@@ -961,14 +982,14 @@ def create_ui(wrap_gradio_gpu_call):
                 html_info,
             ]
         )
 
         extras_send_to_img2img.click(
             fn=lambda x: image_from_url_text(x),
             _js="extract_image_from_gallery_img2img",
             inputs=[result_images],
             outputs=[init_img],
         )
 
         extras_send_to_inpaint.click(
             fn=lambda x: image_from_url_text(x),
             _js="extract_image_from_gallery_inpaint",

@@ -1015,14 +1036,14 @@ def create_ui(wrap_gradio_gpu_call):
     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
 
-    with gr.Blocks() as textual_inversion_interface:
+    with gr.Blocks() as train_interface:
         with gr.Row().style(equal_height=False):
-            with gr.Column():
-                gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
-            with gr.Group():
-                gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
+            gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
+
+        with gr.Row().style(equal_height=False):
+            with gr.Tabs(elem_id="train_tabs"):
+
+                with gr.Tab(label="Create embedding"):
                     new_embedding_name = gr.Textbox(label="Name")
                     initialization_text = gr.Textbox(label="Initialization text", value="*")
                     nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)

@@ -1034,10 +1055,9 @@ def create_ui(wrap_gradio_gpu_call):
                 with gr.Column():
                     create_embedding = gr.Button(value="Create embedding", variant='primary')
 
-        with gr.Group():
-            gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new hypernetwork</p>")
-
+                with gr.Tab(label="Create hypernetwork"):
                     new_hypernetwork_name = gr.Textbox(label="Name")
+                    new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
 
                     with gr.Row():
                         with gr.Column(scale=3):

@@ -1046,9 +1066,7 @@ def create_ui(wrap_gradio_gpu_call):
                 with gr.Column():
                     create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary')
 
-        with gr.Group():
-            gr.HTML(value="<p style='margin-bottom: 0.7em'>Preprocess images</p>")
-
+                with gr.Tab(label="Preprocess images"):
                     process_src = gr.Textbox(label='Source directory')
                     process_dst = gr.Textbox(label='Destination directory')
                     process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)

@@ -1058,6 +1076,10 @@ def create_ui(wrap_gradio_gpu_call):
                     process_flip = gr.Checkbox(label='Create flipped copies')
                     process_split = gr.Checkbox(label='Split oversized images into two')
                     process_caption = gr.Checkbox(label='Use BLIP caption as filename')
+                    if cmd_opts.deepdanbooru:
+                        process_caption_deepbooru = gr.Checkbox(label='Use deepbooru caption as filename')
+                    else:
+                        process_caption_deepbooru = gr.Checkbox(label='Use deepbooru caption as filename', visible=False)
 
                     with gr.Row():
                         with gr.Column(scale=3):

@@ -1066,11 +1088,11 @@ def create_ui(wrap_gradio_gpu_call):
                 with gr.Column():
                     run_preprocess = gr.Button(value="Preprocess", variant='primary')
 
-        with gr.Group():
+                with gr.Tab(label="Train"):
                     gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
                     train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
                     train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
-                    learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
+                    learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
                     dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
                     log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
                     template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))

@@ -1115,6 +1137,7 @@ def create_ui(wrap_gradio_gpu_call):
         fn=modules.hypernetworks.ui.create_hypernetwork,
         inputs=[
             new_hypernetwork_name,
+            new_hypernetwork_sizes,
         ],
         outputs=[
             train_hypernetwork_name,

@@ -1134,6 +1157,7 @@ def create_ui(wrap_gradio_gpu_call):
             process_flip,
             process_split,
             process_caption,
+            process_caption_deepbooru
         ],
         outputs=[
             ti_output,

@@ -1353,13 +1377,14 @@ Requested path was: {f}
         shared.state.interrupt()
         settings_interface.gradio_ref.do_restart = True
 
 
     restart_gradio.click(
         fn=request_restart,
         inputs=[],
         outputs=[],
         _js='function(){restart_reload()}'
     )
 
     if column is not None:
         column.__exit__()

@@ -1369,8 +1394,8 @@ Requested path was: {f}
         (extras_interface, "Extras", "extras"),
         (pnginfo_interface, "PNG Info", "pnginfo"),
         (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
-        (textual_inversion_interface, "Textual inversion", "ti"),
         (images_history, "History", "images_history"),
+        (train_interface, "Train", "ti"),
         (settings_interface, "Settings", "settings"),
     ]

@@ -1393,12 +1418,12 @@ Requested path was: {f}
             component_dict[k] = component
 
     settings_interface.gradio_ref = demo
 
     with gr.Tabs() as tabs:
         for interface, label, ifid in interfaces:
             with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
                 interface.render()
 
     if os.path.exists(os.path.join(script_path, "notification.mp3")):
         audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)

@@ -1531,10 +1556,10 @@ Requested path was: {f}
         if getattr(obj,'custom_script_source',None) is not None:
             key = 'customscript/' + obj.custom_script_source + '/' + key
 
         if getattr(obj, 'do_not_save_to_config', False):
             return
 
         saved_value = ui_settings.get(key, None)
         if saved_value is None:
             ui_settings[key] = getattr(obj, field)

@@ -1558,10 +1583,10 @@ Requested path was: {f}
         if type(x) == gr.Textbox:
             apply_field(x, 'value')
 
         if type(x) == gr.Number:
             apply_field(x, 'value')
 
     visit(txt2img_interface, loadsave, "txt2img")
     visit(img2img_interface, loadsave, "img2img")
     visit(extras_interface, loadsave, "extras")
@@ -4,7 +4,7 @@ fairscale==0.4.4
 fonts
 font-roboto
 gfpgan
-gradio==3.4b3
+gradio==3.4.1
 invisible-watermark
 numpy
 omegaconf
@@ -2,7 +2,7 @@ transformers==4.19.2
 diffusers==0.3.0
 basicsr==1.4.2
 gfpgan==1.3.8
-gradio==3.4b3
+gradio==3.4.1
 numpy==1.23.3
 Pillow==9.2.0
 realesrgan==0.3.0
@@ -129,8 +129,6 @@ class Script(scripts.Script):
         return [original_prompt, original_negative_prompt, cfg, st, randomness, sigma_adjustment]
 
     def run(self, p, original_prompt, original_negative_prompt, cfg, st, randomness, sigma_adjustment):
-        p.batch_size = 1
-        p.batch_count = 1
 
         def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):

@@ -154,7 +152,7 @@ class Script(scripts.Script):
             rec_noise = find_noise_for_image(p, cond, uncond, cfg, st)
             self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment)
 
-        rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], [p.seed + x + 1 for x in range(p.init_latent.shape[0])])
+        rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)
 
         combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
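For context on the unchanged combined_noise line: with independent, roughly unit-variance noise tensors, the blend (1 - r) * rec + r * rand has variance (1 - r)^2 + r^2, so dividing by sqrt(r^2 + (1 - r)^2) keeps the result near unit variance. A quick numeric check (random tensors, purely illustrative):

# Verifies that the blend-and-rescale keeps the noise close to unit standard deviation.
import torch

r = 0.4
rec_noise = torch.randn(100_000)
rand_noise = torch.randn(100_000)
combined = ((1 - r) * rec_noise + r * rand_noise) / ((r**2 + (1 - r)**2) ** 0.5)
print(combined.std())  # ~1.0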
@@ -77,14 +77,42 @@ def apply_sampler(p, x, xs):
     p.sampler_index = sampler_index
 
 
+def confirm_samplers(p, xs):
+    samplers_dict = build_samplers_dict(p)
+    for x in xs:
+        if x.lower() not in samplers_dict.keys():
+            raise RuntimeError(f"Unknown sampler: {x}")
+
+
 def apply_checkpoint(p, x, xs):
     info = modules.sd_models.get_closet_checkpoint_match(x)
-    assert info is not None, f'Checkpoint for {x} not found'
+    if info is None:
+        raise RuntimeError(f"Unknown checkpoint: {x}")
     modules.sd_models.reload_model_weights(shared.sd_model, info)
 
 
+def confirm_checkpoints(p, xs):
+    for x in xs:
+        if modules.sd_models.get_closet_checkpoint_match(x) is None:
+            raise RuntimeError(f"Unknown checkpoint: {x}")
+
+
 def apply_hypernetwork(p, x, xs):
-    hypernetwork.load_hypernetwork(x)
+    if x.lower() in ["", "none"]:
+        name = None
+    else:
+        name = hypernetwork.find_closest_hypernetwork_name(x)
+        if not name:
+            raise RuntimeError(f"Unknown hypernetwork: {x}")
+    hypernetwork.load_hypernetwork(name)
+
+
+def confirm_hypernetworks(p, xs):
+    for x in xs:
+        if x.lower() in ["", "none"]:
+            continue
+        if not hypernetwork.find_closest_hypernetwork_name(x):
+            raise RuntimeError(f"Unknown hypernetwork: {x}")
+
+
 def apply_clip_skip(p, x, xs):

@@ -121,29 +149,29 @@ def str_permutations(x):
     return x
 
 
-AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"])
-AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"])
+AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"])
+AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"])
 
 
 axis_options = [
-    AxisOption("Nothing", str, do_nothing, format_nothing),
-    AxisOption("Seed", int, apply_field("seed"), format_value_add_label),
-    AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label),
-    AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label),
-    AxisOption("Steps", int, apply_field("steps"), format_value_add_label),
-    AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
-    AxisOption("Prompt S/R", str, apply_prompt, format_value),
-    AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list),
-    AxisOption("Sampler", str, apply_sampler, format_value),
-    AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
-    AxisOption("Hypernetwork", str, apply_hypernetwork, format_value),
-    AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label),
-    AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label),
-    AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
-    AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
-    AxisOption("Eta", float, apply_field("eta"), format_value_add_label),
-    AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label),
-    AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
+    AxisOption("Nothing", str, do_nothing, format_nothing, None),
+    AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None),
+    AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None),
+    AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None),
+    AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None),
+    AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None),
+    AxisOption("Prompt S/R", str, apply_prompt, format_value, None),
+    AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None),
+    AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers),
+    AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints),
+    AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks),
+    AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None),
+    AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None),
+    AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None),
+    AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None),
+    AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
+    AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
+    AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
 ]

@@ -197,7 +225,7 @@ class Script(scripts.Script):
             x_values = gr.Textbox(label="X values", visible=False, lines=1)
 
         with gr.Row():
-            y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[4].label, visible=False, type="index", elem_id="y_type")
+            y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, visible=False, type="index", elem_id="y_type")
             y_values = gr.Textbox(label="Y values", visible=False, lines=1)
 
         draw_legend = gr.Checkbox(label='Draw legend', value=True)

@@ -269,17 +297,10 @@ class Script(scripts.Script):
                 valslist = list(permutations(valslist))
 
             valslist = [opt.type(x) for x in valslist]
 
             # Confirm options are valid before starting
-            if opt.label == "Sampler":
-                samplers_dict = build_samplers_dict(p)
-                for sampler_val in valslist:
-                    if sampler_val.lower() not in samplers_dict.keys():
-                        raise RuntimeError(f"Unknown sampler: {sampler_val}")
-            elif opt.label == "Checkpoint name":
-                for ckpt_val in valslist:
-                    if modules.sd_models.get_closet_checkpoint_match(ckpt_val) is None:
-                        raise RuntimeError(f"Checkpoint for {ckpt_val} not found")
+            if opt.confirm:
+                opt.confirm(p, valslist)
 
             return valslist
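After this refactor each AxisOption carries an optional confirm callback, so the per-label validation in Script.run collapses to a single opt.confirm(p, valslist) call. A self-contained sketch of that pattern (the option, sampler names, and values below are invented, not the script's real axis list):

# Minimal illustration of validating axis values up front via a confirm callback.
from collections import namedtuple

AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"])

def confirm_samplers(p, xs):
    known = {"euler a", "euler", "ddim"}          # placeholder sampler names
    for x in xs:
        if x.lower() not in known:
            raise RuntimeError(f"Unknown sampler: {x}")

opt = AxisOption("Sampler", str, apply=None, format_value=None, confirm=confirm_samplers)

valslist = ["Euler a", "DDIM"]
if opt.confirm:
    opt.confirm(None, valslist)                   # raises before any images are generated
print("values confirmed")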
11  style.css
@@ -240,6 +240,7 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s
 #settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{
     position: relative;
     border: none;
+    margin-right: 8em;
 }
 
 .gr-panel div.flex-col div.justify-between label span{

@@ -494,4 +495,14 @@ canvas[key="mask"] {
     filter: invert();
     mix-blend-mode: multiply;
     pointer-events: none;
+}
+
+
+/* gradio 3.4.1 stuff for editable scrollbar values */
+.gr-box > div > div > input.gr-text-input{
+    position: absolute;
+    right: 0.5em;
+    top: -0.6em;
+    z-index: 200;
+    width: 8em;
 }
29  webui.py
@@ -31,12 +31,7 @@ from modules.paths import script_path
 from modules.shared import cmd_opts
 import modules.hypernetworks.hypernetwork
 
-modelloader.cleanup_models()
-modules.sd_models.setup_model()
-codeformer.setup_model(cmd_opts.codeformer_models_path)
-gfpgan.setup_model(cmd_opts.gfpgan_models_path)
-shared.face_restorers.append(modules.face_restoration.FaceRestoration())
-modelloader.load_upscalers()
 queue_lock = threading.Lock()
 
 

@@ -78,15 +73,24 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
     return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
 
 
-modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
+def initialize():
+    modelloader.cleanup_models()
+    modules.sd_models.setup_model()
+    codeformer.setup_model(cmd_opts.codeformer_models_path)
+    gfpgan.setup_model(cmd_opts.gfpgan_models_path)
+    shared.face_restorers.append(modules.face_restoration.FaceRestoration())
+    modelloader.load_upscalers()
 
-shared.sd_model = modules.sd_models.load_model()
-shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
-shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
+    modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
+
+    shared.sd_model = modules.sd_models.load_model()
+    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
+    shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
 
 
 def webui():
+    initialize()
+
     # make the program just exit at ctrl+c without waiting for anything
     def sigint_handler(sig, frame):
         print(f'Interrupted with signal {sig} in {frame}')

@@ -98,7 +102,7 @@ def webui():
 
     demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
 
-    app,local_url,share_url = demo.launch(
+    app, local_url, share_url = demo.launch(
         share=cmd_opts.share,
         server_name="0.0.0.0" if cmd_opts.listen else None,
         server_port=cmd_opts.port,

@@ -124,9 +128,10 @@ def webui():
         modules.scripts.reload_scripts(os.path.join(script_path, "scripts"))
         print('Reloading modules: modules.ui')
         importlib.reload(modules.ui)
+        print('Refreshing Model List')
+        modules.sd_models.list_models()
         print('Restarting Gradio')
 
 
 if __name__ == "__main__":
     webui()