Merge branch 'master' into #1484_fix_empty_styles_pattern
This commit is contained in:
commit
3fac3764b3
38 changed files with 1588 additions and 409 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -25,3 +25,4 @@ __pycache__
|
||||||
/.idea
|
/.idea
|
||||||
notification.mp3
|
notification.mp3
|
||||||
/SwinIR
|
/SwinIR
|
||||||
|
/textual_inversion
|
||||||
|
|
17
README.md
17
README.md
|
@ -11,12 +11,12 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
|
||||||
- One click install and run script (but you still must install python and git)
|
- One click install and run script (but you still must install python and git)
|
||||||
- Outpainting
|
- Outpainting
|
||||||
- Inpainting
|
- Inpainting
|
||||||
- Prompt
|
- Prompt Matrix
|
||||||
- Stable Diffusion upscale
|
- Stable Diffusion Upscale
|
||||||
- Attention, specify parts of text that the model should pay more attention to
|
- Attention, specify parts of text that the model should pay more attention to
|
||||||
- a man in a ((txuedo)) - will pay more attentinoto tuxedo
|
- a man in a ((tuxedo)) - will pay more attention to tuxedo
|
||||||
- a man in a (txuedo:1.21) - alternative syntax
|
- a man in a (tuxedo:1.21) - alternative syntax
|
||||||
- Loopback, run img2img procvessing multiple times
|
- Loopback, run img2img processing multiple times
|
||||||
- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
|
- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
|
||||||
- Textual Inversion
|
- Textual Inversion
|
||||||
- have as many embeddings as you want and use any names you like for them
|
- have as many embeddings as you want and use any names you like for them
|
||||||
|
@ -35,15 +35,15 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
|
||||||
- 4GB video card support (also reports of 2GB working)
|
- 4GB video card support (also reports of 2GB working)
|
||||||
- Correct seeds for batches
|
- Correct seeds for batches
|
||||||
- Prompt length validation
|
- Prompt length validation
|
||||||
- get length of prompt in tokensas you type
|
- get length of prompt in tokens as you type
|
||||||
- get a warning after geenration if some text was truncated
|
- get a warning after generation if some text was truncated
|
||||||
- Generation parameters
|
- Generation parameters
|
||||||
- parameters you used to generate images are saved with that image
|
- parameters you used to generate images are saved with that image
|
||||||
- in PNG chunks for PNG, in EXIF for JPEG
|
- in PNG chunks for PNG, in EXIF for JPEG
|
||||||
- can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
|
- can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
|
||||||
- can be disabled in settings
|
- can be disabled in settings
|
||||||
- Settings page
|
- Settings page
|
||||||
- Running arbitrary python code from UI (must run with commandline flag to enable)
|
- Running arbitrary python code from UI (must run with --allow-code to enable)
|
||||||
- Mouseover hints for most UI elements
|
- Mouseover hints for most UI elements
|
||||||
- Possible to change defaults/mix/max/step values for UI elements via text config
|
- Possible to change defaults/mix/max/step values for UI elements via text config
|
||||||
- Random artist button
|
- Random artist button
|
||||||
|
@ -113,6 +113,7 @@ The documentation was moved from this README over to the project's [wiki](https:
|
||||||
- LDSR - https://github.com/Hafiidz/latent-diffusion
|
- LDSR - https://github.com/Hafiidz/latent-diffusion
|
||||||
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
|
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
|
||||||
- Doggettx - Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
|
- Doggettx - Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
|
||||||
|
- Rinon Gal - Textual Inversion - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
|
||||||
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
|
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
|
||||||
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
|
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
|
||||||
- CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
|
- CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
|
||||||
|
|
|
@ -30,6 +30,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte
|
||||||
onUiUpdate(function(){
|
onUiUpdate(function(){
|
||||||
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
|
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
|
||||||
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
|
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
|
||||||
|
check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', 'ti_interrupt', 'ti_preview', 'ti_gallery')
|
||||||
})
|
})
|
||||||
|
|
||||||
function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){
|
function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){
|
||||||
|
|
8
javascript/textualInversion.js
Normal file
8
javascript/textualInversion.js
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
|
||||||
|
|
||||||
|
function start_training_textual_inversion(){
|
||||||
|
requestProgress('ti')
|
||||||
|
gradioApp().querySelector('#ti_error').innerHTML=''
|
||||||
|
|
||||||
|
return args_to_array(arguments)
|
||||||
|
}
|
|
@ -199,6 +199,26 @@ let txt2img_textarea, img2img_textarea = undefined;
|
||||||
let wait_time = 800
|
let wait_time = 800
|
||||||
let token_timeout;
|
let token_timeout;
|
||||||
|
|
||||||
|
function update_txt2img_tokens(...args) {
|
||||||
|
update_token_counter("txt2img_token_button")
|
||||||
|
if (args.length == 2)
|
||||||
|
return args[0]
|
||||||
|
return args;
|
||||||
|
}
|
||||||
|
|
||||||
|
function update_img2img_tokens(...args) {
|
||||||
|
update_token_counter("img2img_token_button")
|
||||||
|
if (args.length == 2)
|
||||||
|
return args[0]
|
||||||
|
return args;
|
||||||
|
}
|
||||||
|
|
||||||
|
function update_token_counter(button_id) {
|
||||||
|
if (token_timeout)
|
||||||
|
clearTimeout(token_timeout);
|
||||||
|
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
||||||
|
}
|
||||||
|
|
||||||
function submit_prompt(event, generate_button_id) {
|
function submit_prompt(event, generate_button_id) {
|
||||||
if (event.altKey && event.keyCode === 13) {
|
if (event.altKey && event.keyCode === 13) {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
|
@ -207,8 +227,7 @@ function submit_prompt(event, generate_button_id) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function update_token_counter(button_id) {
|
function restart_reload(){
|
||||||
if (token_timeout)
|
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
|
||||||
clearTimeout(token_timeout);
|
setTimeout(function(){location.reload()},2000)
|
||||||
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@ requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
||||||
|
|
||||||
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
||||||
|
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
|
||||||
|
|
||||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
|
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
|
||||||
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
|
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
|
||||||
|
@ -111,6 +112,9 @@ if not skip_torch_cuda_test:
|
||||||
if not is_installed("gfpgan"):
|
if not is_installed("gfpgan"):
|
||||||
run_pip(f"install {gfpgan_package}", "gfpgan")
|
run_pip(f"install {gfpgan_package}", "gfpgan")
|
||||||
|
|
||||||
|
if not is_installed("clip"):
|
||||||
|
run_pip(f"install {clip_package}", "clip")
|
||||||
|
|
||||||
os.makedirs(dir_repos, exist_ok=True)
|
os.makedirs(dir_repos, exist_ok=True)
|
||||||
|
|
||||||
git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
|
git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||||
|
|
|
@ -32,10 +32,9 @@ def enable_tf32():
|
||||||
|
|
||||||
errors.run(enable_tf32, "Enabling TF32")
|
errors.run(enable_tf32, "Enabling TF32")
|
||||||
|
|
||||||
|
|
||||||
device = get_optimal_device()
|
device = get_optimal_device()
|
||||||
device_codeformer = cpu if has_mps else device
|
device_codeformer = cpu if has_mps else device
|
||||||
|
dtype = torch.float16
|
||||||
|
|
||||||
def randn(seed, shape):
|
def randn(seed, shape):
|
||||||
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
|
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
|
||||||
|
|
|
@ -73,8 +73,8 @@ def fix_model_layers(crt_model, pretrained_net):
|
||||||
class UpscalerESRGAN(Upscaler):
|
class UpscalerESRGAN(Upscaler):
|
||||||
def __init__(self, dirname):
|
def __init__(self, dirname):
|
||||||
self.name = "ESRGAN"
|
self.name = "ESRGAN"
|
||||||
self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
|
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
|
||||||
self.model_name = "ESRGAN 4x"
|
self.model_name = "ESRGAN_4x"
|
||||||
self.scalers = []
|
self.scalers = []
|
||||||
self.user_path = dirname
|
self.user_path = dirname
|
||||||
self.model_path = os.path.join(models_path, self.name)
|
self.model_path = os.path.join(models_path, self.name)
|
||||||
|
|
|
@ -311,7 +311,12 @@ def apply_filename_pattern(x, p, seed, prompt):
|
||||||
x = x.replace("[cfg]", str(p.cfg_scale))
|
x = x.replace("[cfg]", str(p.cfg_scale))
|
||||||
x = x.replace("[width]", str(p.width))
|
x = x.replace("[width]", str(p.width))
|
||||||
x = x.replace("[height]", str(p.height))
|
x = x.replace("[height]", str(p.height))
|
||||||
x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "No styles", replace_spaces=False))
|
|
||||||
|
#currently disabled if using the save button, will work otherwise
|
||||||
|
# if enabled it will cause a bug because styles is not included in the save_files data dictionary
|
||||||
|
if hasattr(p, "styles"):
|
||||||
|
x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"] or "None"), replace_spaces=False))
|
||||||
|
|
||||||
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
|
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
|
||||||
|
|
||||||
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
|
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
|
||||||
|
|
|
@ -103,7 +103,9 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
|
||||||
inpaint_full_res_padding=inpaint_full_res_padding,
|
inpaint_full_res_padding=inpaint_full_res_padding,
|
||||||
inpainting_mask_invert=inpainting_mask_invert,
|
inpainting_mask_invert=inpainting_mask_invert,
|
||||||
)
|
)
|
||||||
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
|
||||||
|
if shared.cmd_opts.enable_console_prompts:
|
||||||
|
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
||||||
|
|
||||||
p.extra_generation_params["Mask blur"] = mask_blur
|
p.extra_generation_params["Mask blur"] = mask_blur
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ Category = namedtuple("Category", ["name", "topn", "items"])
|
||||||
|
|
||||||
re_topn = re.compile(r"\.top(\d+)\.")
|
re_topn = re.compile(r"\.top(\d+)\.")
|
||||||
|
|
||||||
|
|
||||||
class InterrogateModels:
|
class InterrogateModels:
|
||||||
blip_model = None
|
blip_model = None
|
||||||
clip_model = None
|
clip_model = None
|
||||||
|
|
|
@ -5,7 +5,6 @@ import importlib
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.upscaler import Upscaler
|
from modules.upscaler import Upscaler
|
||||||
from modules.paths import script_path, models_path
|
from modules.paths import script_path, models_path
|
||||||
|
@ -43,7 +42,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
||||||
for place in places:
|
for place in places:
|
||||||
if os.path.exists(place):
|
if os.path.exists(place):
|
||||||
for file in glob.iglob(place + '**/**', recursive=True):
|
for file in glob.iglob(place + '**/**', recursive=True):
|
||||||
full_path = os.path.join(place, file)
|
full_path = file
|
||||||
if os.path.isdir(full_path):
|
if os.path.isdir(full_path):
|
||||||
continue
|
continue
|
||||||
if len(ext_filter) != 0:
|
if len(ext_filter) != 0:
|
||||||
|
@ -121,16 +120,30 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
|
||||||
|
|
||||||
|
|
||||||
def load_upscalers():
|
def load_upscalers():
|
||||||
|
sd = shared.script_path
|
||||||
|
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
|
||||||
|
# so we'll try to import any _model.py files before looking in __subclasses__
|
||||||
|
modules_dir = os.path.join(sd, "modules")
|
||||||
|
for file in os.listdir(modules_dir):
|
||||||
|
if "_model.py" in file:
|
||||||
|
model_name = file.replace("_model.py", "")
|
||||||
|
full_model = f"modules.{model_name}_model"
|
||||||
|
try:
|
||||||
|
importlib.import_module(full_model)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
datas = []
|
datas = []
|
||||||
|
c_o = vars(shared.cmd_opts)
|
||||||
for cls in Upscaler.__subclasses__():
|
for cls in Upscaler.__subclasses__():
|
||||||
name = cls.__name__
|
name = cls.__name__
|
||||||
module_name = cls.__module__
|
module_name = cls.__module__
|
||||||
module = importlib.import_module(module_name)
|
module = importlib.import_module(module_name)
|
||||||
class_ = getattr(module, name)
|
class_ = getattr(module, name)
|
||||||
cmd_name = f"{name.lower().replace('upscaler', '')}-models-path"
|
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
|
||||||
opt_string = None
|
opt_string = None
|
||||||
try:
|
try:
|
||||||
opt_string = shared.opts.__getattr__(cmd_name)
|
if cmd_name in c_o:
|
||||||
|
opt_string = c_o[cmd_name]
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
scaler = class_(opt_string)
|
scaler = class_(opt_string)
|
||||||
|
|
|
@ -20,7 +20,6 @@ path_dirs = [
|
||||||
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
|
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
|
||||||
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
|
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
|
||||||
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
|
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
|
||||||
(os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR', []),
|
|
||||||
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
|
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -56,7 +56,7 @@ class StableDiffusionProcessing:
|
||||||
self.prompt: str = prompt
|
self.prompt: str = prompt
|
||||||
self.prompt_for_display: str = None
|
self.prompt_for_display: str = None
|
||||||
self.negative_prompt: str = (negative_prompt or "")
|
self.negative_prompt: str = (negative_prompt or "")
|
||||||
self.styles: str = styles
|
self.styles: list = styles or []
|
||||||
self.seed: int = seed
|
self.seed: int = seed
|
||||||
self.subseed: int = subseed
|
self.subseed: int = subseed
|
||||||
self.subseed_strength: float = subseed_strength
|
self.subseed_strength: float = subseed_strength
|
||||||
|
@ -271,7 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
||||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||||
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||||
"Denoising strength": getattr(p, 'denoising_strength', None),
|
"Denoising strength": getattr(p, 'denoising_strength', None),
|
||||||
"Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
||||||
}
|
}
|
||||||
|
|
||||||
generation_params.update(p.extra_generation_params)
|
generation_params.update(p.extra_generation_params)
|
||||||
|
@ -295,8 +295,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
|
|
||||||
fix_seed(p)
|
fix_seed(p)
|
||||||
|
|
||||||
os.makedirs(p.outpath_samples, exist_ok=True)
|
if p.outpath_samples is not None:
|
||||||
os.makedirs(p.outpath_grids, exist_ok=True)
|
os.makedirs(p.outpath_samples, exist_ok=True)
|
||||||
|
|
||||||
|
if p.outpath_grids is not None:
|
||||||
|
os.makedirs(p.outpath_grids, exist_ok=True)
|
||||||
|
|
||||||
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
||||||
|
|
||||||
|
@ -323,7 +326,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
|
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
|
||||||
|
|
||||||
if os.path.exists(cmd_opts.embeddings_dir):
|
if os.path.exists(cmd_opts.embeddings_dir):
|
||||||
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
|
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||||
|
|
||||||
infotexts = []
|
infotexts = []
|
||||||
output_images = []
|
output_images = []
|
||||||
|
|
|
@ -162,6 +162,40 @@ class ScriptRunner:
|
||||||
|
|
||||||
return processed
|
return processed
|
||||||
|
|
||||||
|
def reload_sources(self):
|
||||||
|
for si, script in list(enumerate(self.scripts)):
|
||||||
|
with open(script.filename, "r", encoding="utf8") as file:
|
||||||
|
args_from = script.args_from
|
||||||
|
args_to = script.args_to
|
||||||
|
filename = script.filename
|
||||||
|
text = file.read()
|
||||||
|
|
||||||
|
from types import ModuleType
|
||||||
|
|
||||||
|
compiled = compile(text, filename, 'exec')
|
||||||
|
module = ModuleType(script.filename)
|
||||||
|
exec(compiled, module.__dict__)
|
||||||
|
|
||||||
|
for key, script_class in module.__dict__.items():
|
||||||
|
if type(script_class) == type and issubclass(script_class, Script):
|
||||||
|
self.scripts[si] = script_class()
|
||||||
|
self.scripts[si].filename = filename
|
||||||
|
self.scripts[si].args_from = args_from
|
||||||
|
self.scripts[si].args_to = args_to
|
||||||
|
|
||||||
scripts_txt2img = ScriptRunner()
|
scripts_txt2img = ScriptRunner()
|
||||||
scripts_img2img = ScriptRunner()
|
scripts_img2img = ScriptRunner()
|
||||||
|
|
||||||
|
def reload_script_body_only():
|
||||||
|
scripts_txt2img.reload_sources()
|
||||||
|
scripts_img2img.reload_sources()
|
||||||
|
|
||||||
|
|
||||||
|
def reload_scripts(basedir):
|
||||||
|
global scripts_txt2img, scripts_img2img
|
||||||
|
|
||||||
|
scripts_data.clear()
|
||||||
|
load_scripts(basedir)
|
||||||
|
|
||||||
|
scripts_txt2img = ScriptRunner()
|
||||||
|
scripts_img2img = ScriptRunner()
|
||||||
|
|
90
modules/scunet_model.py
Normal file
90
modules/scunet_model.py
Normal file
|
@ -0,0 +1,90 @@
|
||||||
|
import os.path
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import PIL.Image
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
|
import modules.upscaler
|
||||||
|
from modules import shared, modelloader
|
||||||
|
from modules.paths import models_path
|
||||||
|
from modules.scunet_model_arch import SCUNet as net
|
||||||
|
|
||||||
|
|
||||||
|
class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||||
|
def __init__(self, dirname):
|
||||||
|
self.name = "ScuNET"
|
||||||
|
self.model_path = os.path.join(models_path, self.name)
|
||||||
|
self.model_name = "ScuNET GAN"
|
||||||
|
self.model_name2 = "ScuNET PSNR"
|
||||||
|
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
|
||||||
|
self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth"
|
||||||
|
self.user_path = dirname
|
||||||
|
super().__init__()
|
||||||
|
model_paths = self.find_models(ext_filter=[".pth"])
|
||||||
|
scalers = []
|
||||||
|
add_model2 = True
|
||||||
|
for file in model_paths:
|
||||||
|
if "http" in file:
|
||||||
|
name = self.model_name
|
||||||
|
else:
|
||||||
|
name = modelloader.friendly_name(file)
|
||||||
|
if name == self.model_name2 or file == self.model_url2:
|
||||||
|
add_model2 = False
|
||||||
|
try:
|
||||||
|
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
|
||||||
|
scalers.append(scaler_data)
|
||||||
|
except Exception:
|
||||||
|
print(f"Error loading ScuNET model: {file}", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
if add_model2:
|
||||||
|
scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
|
||||||
|
scalers.append(scaler_data2)
|
||||||
|
self.scalers = scalers
|
||||||
|
|
||||||
|
def do_upscale(self, img: PIL.Image, selected_file):
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
model = self.load_model(selected_file)
|
||||||
|
if model is None:
|
||||||
|
return img
|
||||||
|
|
||||||
|
device = shared.device
|
||||||
|
img = np.array(img)
|
||||||
|
img = img[:, :, ::-1]
|
||||||
|
img = np.moveaxis(img, 2, 0) / 255
|
||||||
|
img = torch.from_numpy(img).float()
|
||||||
|
img = img.unsqueeze(0).to(shared.device)
|
||||||
|
|
||||||
|
img = img.to(device)
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(img)
|
||||||
|
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||||
|
output = 255. * np.moveaxis(output, 0, 2)
|
||||||
|
output = output.astype(np.uint8)
|
||||||
|
output = output[:, :, ::-1]
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return PIL.Image.fromarray(output, 'RGB')
|
||||||
|
|
||||||
|
def load_model(self, path: str):
|
||||||
|
device = shared.device
|
||||||
|
if "http" in path:
|
||||||
|
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
|
||||||
|
progress=True)
|
||||||
|
else:
|
||||||
|
filename = path
|
||||||
|
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
|
||||||
|
print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
|
||||||
|
return None
|
||||||
|
|
||||||
|
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
|
||||||
|
model.load_state_dict(torch.load(filename), strict=True)
|
||||||
|
model.eval()
|
||||||
|
for k, v in model.named_parameters():
|
||||||
|
v.requires_grad = False
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
265
modules/scunet_model_arch.py
Normal file
265
modules/scunet_model_arch.py
Normal file
|
@ -0,0 +1,265 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from einops import rearrange
|
||||||
|
from einops.layers.torch import Rearrange
|
||||||
|
from timm.models.layers import trunc_normal_, DropPath
|
||||||
|
|
||||||
|
|
||||||
|
class WMSA(nn.Module):
|
||||||
|
""" Self-attention module in Swin Transformer
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, input_dim, output_dim, head_dim, window_size, type):
|
||||||
|
super(WMSA, self).__init__()
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.output_dim = output_dim
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.scale = self.head_dim ** -0.5
|
||||||
|
self.n_heads = input_dim // head_dim
|
||||||
|
self.window_size = window_size
|
||||||
|
self.type = type
|
||||||
|
self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
|
||||||
|
|
||||||
|
self.relative_position_params = nn.Parameter(
|
||||||
|
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))
|
||||||
|
|
||||||
|
self.linear = nn.Linear(self.input_dim, self.output_dim)
|
||||||
|
|
||||||
|
trunc_normal_(self.relative_position_params, std=.02)
|
||||||
|
self.relative_position_params = torch.nn.Parameter(
|
||||||
|
self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1,
|
||||||
|
2).transpose(
|
||||||
|
0, 1))
|
||||||
|
|
||||||
|
def generate_mask(self, h, w, p, shift):
|
||||||
|
""" generating the mask of SW-MSA
|
||||||
|
Args:
|
||||||
|
shift: shift parameters in CyclicShift.
|
||||||
|
Returns:
|
||||||
|
attn_mask: should be (1 1 w p p),
|
||||||
|
"""
|
||||||
|
# supporting sqaure.
|
||||||
|
attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
|
||||||
|
if self.type == 'W':
|
||||||
|
return attn_mask
|
||||||
|
|
||||||
|
s = p - shift
|
||||||
|
attn_mask[-1, :, :s, :, s:, :] = True
|
||||||
|
attn_mask[-1, :, s:, :, :s, :] = True
|
||||||
|
attn_mask[:, -1, :, :s, :, s:] = True
|
||||||
|
attn_mask[:, -1, :, s:, :, :s] = True
|
||||||
|
attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
|
||||||
|
return attn_mask
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
""" Forward pass of Window Multi-head Self-attention module.
|
||||||
|
Args:
|
||||||
|
x: input tensor with shape of [b h w c];
|
||||||
|
attn_mask: attention mask, fill -inf where the value is True;
|
||||||
|
Returns:
|
||||||
|
output: tensor shape [b h w c]
|
||||||
|
"""
|
||||||
|
if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
|
||||||
|
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
|
||||||
|
h_windows = x.size(1)
|
||||||
|
w_windows = x.size(2)
|
||||||
|
# sqaure validation
|
||||||
|
# assert h_windows == w_windows
|
||||||
|
|
||||||
|
x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
|
||||||
|
qkv = self.embedding_layer(x)
|
||||||
|
q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
|
||||||
|
sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
|
||||||
|
# Adding learnable relative embedding
|
||||||
|
sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
|
||||||
|
# Using Attn Mask to distinguish different subwindows.
|
||||||
|
if self.type != 'W':
|
||||||
|
attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2)
|
||||||
|
sim = sim.masked_fill_(attn_mask, float("-inf"))
|
||||||
|
|
||||||
|
probs = nn.functional.softmax(sim, dim=-1)
|
||||||
|
output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
|
||||||
|
output = rearrange(output, 'h b w p c -> b w p (h c)')
|
||||||
|
output = self.linear(output)
|
||||||
|
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
|
||||||
|
|
||||||
|
if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
|
||||||
|
dims=(1, 2))
|
||||||
|
return output
|
||||||
|
|
||||||
|
def relative_embedding(self):
|
||||||
|
cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
|
||||||
|
relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
|
||||||
|
# negative is allowed
|
||||||
|
return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]
|
||||||
|
|
||||||
|
|
||||||
|
class Block(nn.Module):
|
||||||
|
def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
|
||||||
|
""" SwinTransformer Block
|
||||||
|
"""
|
||||||
|
super(Block, self).__init__()
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.output_dim = output_dim
|
||||||
|
assert type in ['W', 'SW']
|
||||||
|
self.type = type
|
||||||
|
if input_resolution <= window_size:
|
||||||
|
self.type = 'W'
|
||||||
|
|
||||||
|
self.ln1 = nn.LayerNorm(input_dim)
|
||||||
|
self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
|
||||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||||
|
self.ln2 = nn.LayerNorm(input_dim)
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
nn.Linear(input_dim, 4 * input_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Linear(4 * input_dim, output_dim),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x + self.drop_path(self.msa(self.ln1(x)))
|
||||||
|
x = x + self.drop_path(self.mlp(self.ln2(x)))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ConvTransBlock(nn.Module):
|
||||||
|
def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
|
||||||
|
""" SwinTransformer and Conv Block
|
||||||
|
"""
|
||||||
|
super(ConvTransBlock, self).__init__()
|
||||||
|
self.conv_dim = conv_dim
|
||||||
|
self.trans_dim = trans_dim
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.window_size = window_size
|
||||||
|
self.drop_path = drop_path
|
||||||
|
self.type = type
|
||||||
|
self.input_resolution = input_resolution
|
||||||
|
|
||||||
|
assert self.type in ['W', 'SW']
|
||||||
|
if self.input_resolution <= self.window_size:
|
||||||
|
self.type = 'W'
|
||||||
|
|
||||||
|
self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path,
|
||||||
|
self.type, self.input_resolution)
|
||||||
|
self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
|
||||||
|
self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
|
||||||
|
|
||||||
|
self.conv_block = nn.Sequential(
|
||||||
|
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
|
||||||
|
nn.ReLU(True),
|
||||||
|
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
|
||||||
|
conv_x = self.conv_block(conv_x) + conv_x
|
||||||
|
trans_x = Rearrange('b c h w -> b h w c')(trans_x)
|
||||||
|
trans_x = self.trans_block(trans_x)
|
||||||
|
trans_x = Rearrange('b h w c -> b c h w')(trans_x)
|
||||||
|
res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
|
||||||
|
x = x + res
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SCUNet(nn.Module):
|
||||||
|
# def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256):
|
||||||
|
def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
|
||||||
|
super(SCUNet, self).__init__()
|
||||||
|
if config is None:
|
||||||
|
config = [2, 2, 2, 2, 2, 2, 2]
|
||||||
|
self.config = config
|
||||||
|
self.dim = dim
|
||||||
|
self.head_dim = 32
|
||||||
|
self.window_size = 8
|
||||||
|
|
||||||
|
# drop path rate for each layer
|
||||||
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
|
||||||
|
|
||||||
|
self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
|
||||||
|
|
||||||
|
begin = 0
|
||||||
|
self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
|
||||||
|
'W' if not i % 2 else 'SW', input_resolution)
|
||||||
|
for i in range(config[0])] + \
|
||||||
|
[nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
|
||||||
|
|
||||||
|
begin += config[0]
|
||||||
|
self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
|
||||||
|
'W' if not i % 2 else 'SW', input_resolution // 2)
|
||||||
|
for i in range(config[1])] + \
|
||||||
|
[nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
|
||||||
|
|
||||||
|
begin += config[1]
|
||||||
|
self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
|
||||||
|
'W' if not i % 2 else 'SW', input_resolution // 4)
|
||||||
|
for i in range(config[2])] + \
|
||||||
|
[nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
|
||||||
|
|
||||||
|
begin += config[2]
|
||||||
|
self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin],
|
||||||
|
'W' if not i % 2 else 'SW', input_resolution // 8)
|
||||||
|
for i in range(config[3])]
|
||||||
|
|
||||||
|
begin += config[3]
|
||||||
|
self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \
|
||||||
|
[ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
|
||||||
|
'W' if not i % 2 else 'SW', input_resolution // 4)
|
||||||
|
for i in range(config[4])]
|
||||||
|
|
||||||
|
begin += config[4]
|
||||||
|
self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \
|
||||||
|
[ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
|
||||||
|
'W' if not i % 2 else 'SW', input_resolution // 2)
|
||||||
|
for i in range(config[5])]
|
||||||
|
|
||||||
|
begin += config[5]
|
||||||
|
self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \
|
||||||
|
[ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
|
||||||
|
'W' if not i % 2 else 'SW', input_resolution)
|
||||||
|
for i in range(config[6])]
|
||||||
|
|
||||||
|
self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
|
||||||
|
|
||||||
|
self.m_head = nn.Sequential(*self.m_head)
|
||||||
|
self.m_down1 = nn.Sequential(*self.m_down1)
|
||||||
|
self.m_down2 = nn.Sequential(*self.m_down2)
|
||||||
|
self.m_down3 = nn.Sequential(*self.m_down3)
|
||||||
|
self.m_body = nn.Sequential(*self.m_body)
|
||||||
|
self.m_up3 = nn.Sequential(*self.m_up3)
|
||||||
|
self.m_up2 = nn.Sequential(*self.m_up2)
|
||||||
|
self.m_up1 = nn.Sequential(*self.m_up1)
|
||||||
|
self.m_tail = nn.Sequential(*self.m_tail)
|
||||||
|
# self.apply(self._init_weights)
|
||||||
|
|
||||||
|
def forward(self, x0):
|
||||||
|
|
||||||
|
h, w = x0.size()[-2:]
|
||||||
|
paddingBottom = int(np.ceil(h / 64) * 64 - h)
|
||||||
|
paddingRight = int(np.ceil(w / 64) * 64 - w)
|
||||||
|
x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
|
||||||
|
|
||||||
|
x1 = self.m_head(x0)
|
||||||
|
x2 = self.m_down1(x1)
|
||||||
|
x3 = self.m_down2(x2)
|
||||||
|
x4 = self.m_down3(x3)
|
||||||
|
x = self.m_body(x4)
|
||||||
|
x = self.m_up3(x + x4)
|
||||||
|
x = self.m_up2(x + x3)
|
||||||
|
x = self.m_up1(x + x2)
|
||||||
|
x = self.m_tail(x + x1)
|
||||||
|
|
||||||
|
x = x[..., :h, :w]
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _init_weights(self, m):
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
trunc_normal_(m.weight, std=.02)
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
elif isinstance(m, nn.LayerNorm):
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
nn.init.constant_(m.weight, 1.0)
|
|
@ -6,244 +6,41 @@ import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
|
|
||||||
from modules import prompt_parser
|
import modules.textual_inversion.textual_inversion
|
||||||
|
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
|
||||||
from modules.shared import opts, device, cmd_opts
|
from modules.shared import opts, device, cmd_opts
|
||||||
|
|
||||||
from ldm.util import default
|
|
||||||
from einops import rearrange
|
|
||||||
import ldm.modules.attention
|
import ldm.modules.attention
|
||||||
import ldm.modules.diffusionmodules.model
|
import ldm.modules.diffusionmodules.model
|
||||||
|
|
||||||
|
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
||||||
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||||
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||||
h = self.heads
|
|
||||||
|
|
||||||
q = self.to_q(x)
|
|
||||||
context = default(context, x)
|
|
||||||
k = self.to_k(context)
|
|
||||||
v = self.to_v(context)
|
|
||||||
del context, x
|
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
|
||||||
|
|
||||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
|
||||||
for i in range(0, q.shape[0], 2):
|
|
||||||
end = i + 2
|
|
||||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
|
||||||
s1 *= self.scale
|
|
||||||
|
|
||||||
s2 = s1.softmax(dim=-1)
|
|
||||||
del s1
|
|
||||||
|
|
||||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
|
||||||
del s2
|
|
||||||
|
|
||||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
|
||||||
del r1
|
|
||||||
|
|
||||||
return self.to_out(r2)
|
|
||||||
|
|
||||||
|
|
||||||
# taken from https://github.com/Doggettx/stable-diffusion
|
def apply_optimizations():
|
||||||
def split_cross_attention_forward(self, x, context=None, mask=None):
|
if cmd_opts.opt_split_attention_v1:
|
||||||
h = self.heads
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
||||||
|
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
||||||
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
|
||||||
|
ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack
|
||||||
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
|
||||||
|
|
||||||
q_in = self.to_q(x)
|
|
||||||
context = default(context, x)
|
|
||||||
k_in = self.to_k(context) * self.scale
|
|
||||||
v_in = self.to_v(context)
|
|
||||||
del context, x
|
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
def undo_optimizations():
|
||||||
del q_in, k_in, v_in
|
ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward
|
||||||
|
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
|
||||||
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||||
|
|
||||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
|
||||||
|
|
||||||
stats = torch.cuda.memory_stats(q.device)
|
|
||||||
mem_active = stats['active_bytes.all.current']
|
|
||||||
mem_reserved = stats['reserved_bytes.all.current']
|
|
||||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
|
||||||
mem_free_torch = mem_reserved - mem_active
|
|
||||||
mem_free_total = mem_free_cuda + mem_free_torch
|
|
||||||
|
|
||||||
gb = 1024 ** 3
|
|
||||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
|
||||||
modifier = 3 if q.element_size() == 2 else 2.5
|
|
||||||
mem_required = tensor_size * modifier
|
|
||||||
steps = 1
|
|
||||||
|
|
||||||
if mem_required > mem_free_total:
|
|
||||||
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
|
||||||
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
|
||||||
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
|
||||||
|
|
||||||
if steps > 64:
|
|
||||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
|
||||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
|
||||||
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
|
||||||
|
|
||||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
|
||||||
for i in range(0, q.shape[1], slice_size):
|
|
||||||
end = i + slice_size
|
|
||||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
|
||||||
|
|
||||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
|
||||||
del s1
|
|
||||||
|
|
||||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
|
||||||
del s2
|
|
||||||
|
|
||||||
del q, k, v
|
|
||||||
|
|
||||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
|
||||||
del r1
|
|
||||||
|
|
||||||
return self.to_out(r2)
|
|
||||||
|
|
||||||
def nonlinearity_hijack(x):
|
|
||||||
# swish
|
|
||||||
t = torch.sigmoid(x)
|
|
||||||
x *= t
|
|
||||||
del t
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def cross_attention_attnblock_forward(self, x):
|
|
||||||
h_ = x
|
|
||||||
h_ = self.norm(h_)
|
|
||||||
q1 = self.q(h_)
|
|
||||||
k1 = self.k(h_)
|
|
||||||
v = self.v(h_)
|
|
||||||
|
|
||||||
# compute attention
|
|
||||||
b, c, h, w = q1.shape
|
|
||||||
|
|
||||||
q2 = q1.reshape(b, c, h*w)
|
|
||||||
del q1
|
|
||||||
|
|
||||||
q = q2.permute(0, 2, 1) # b,hw,c
|
|
||||||
del q2
|
|
||||||
|
|
||||||
k = k1.reshape(b, c, h*w) # b,c,hw
|
|
||||||
del k1
|
|
||||||
|
|
||||||
h_ = torch.zeros_like(k, device=q.device)
|
|
||||||
|
|
||||||
stats = torch.cuda.memory_stats(q.device)
|
|
||||||
mem_active = stats['active_bytes.all.current']
|
|
||||||
mem_reserved = stats['reserved_bytes.all.current']
|
|
||||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
|
||||||
mem_free_torch = mem_reserved - mem_active
|
|
||||||
mem_free_total = mem_free_cuda + mem_free_torch
|
|
||||||
|
|
||||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
|
||||||
mem_required = tensor_size * 2.5
|
|
||||||
steps = 1
|
|
||||||
|
|
||||||
if mem_required > mem_free_total:
|
|
||||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
|
||||||
|
|
||||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
|
||||||
for i in range(0, q.shape[1], slice_size):
|
|
||||||
end = i + slice_size
|
|
||||||
|
|
||||||
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
|
||||||
w2 = w1 * (int(c)**(-0.5))
|
|
||||||
del w1
|
|
||||||
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
|
|
||||||
del w2
|
|
||||||
|
|
||||||
# attend to values
|
|
||||||
v1 = v.reshape(b, c, h*w)
|
|
||||||
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
|
||||||
del w3
|
|
||||||
|
|
||||||
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
|
||||||
del v1, w4
|
|
||||||
|
|
||||||
h2 = h_.reshape(b, c, h, w)
|
|
||||||
del h_
|
|
||||||
|
|
||||||
h3 = self.proj_out(h2)
|
|
||||||
del h2
|
|
||||||
|
|
||||||
h3 += x
|
|
||||||
|
|
||||||
return h3
|
|
||||||
|
|
||||||
class StableDiffusionModelHijack:
|
class StableDiffusionModelHijack:
|
||||||
ids_lookup = {}
|
|
||||||
word_embeddings = {}
|
|
||||||
word_embeddings_checksums = {}
|
|
||||||
fixes = None
|
fixes = None
|
||||||
comments = []
|
comments = []
|
||||||
dir_mtime = None
|
|
||||||
layers = None
|
layers = None
|
||||||
circular_enabled = False
|
circular_enabled = False
|
||||||
clip = None
|
clip = None
|
||||||
|
|
||||||
def load_textual_inversion_embeddings(self, dirname, model):
|
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
|
||||||
mt = os.path.getmtime(dirname)
|
|
||||||
if self.dir_mtime is not None and mt <= self.dir_mtime:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.dir_mtime = mt
|
|
||||||
self.ids_lookup.clear()
|
|
||||||
self.word_embeddings.clear()
|
|
||||||
|
|
||||||
tokenizer = model.cond_stage_model.tokenizer
|
|
||||||
|
|
||||||
def const_hash(a):
|
|
||||||
r = 0
|
|
||||||
for v in a:
|
|
||||||
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
|
|
||||||
return r
|
|
||||||
|
|
||||||
def process_file(path, filename):
|
|
||||||
name = os.path.splitext(filename)[0]
|
|
||||||
|
|
||||||
data = torch.load(path, map_location="cpu")
|
|
||||||
|
|
||||||
# textual inversion embeddings
|
|
||||||
if 'string_to_param' in data:
|
|
||||||
param_dict = data['string_to_param']
|
|
||||||
if hasattr(param_dict, '_parameters'):
|
|
||||||
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
|
||||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
|
||||||
emb = next(iter(param_dict.items()))[1]
|
|
||||||
# diffuser concepts
|
|
||||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
|
||||||
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
|
||||||
|
|
||||||
emb = next(iter(data.values()))
|
|
||||||
if len(emb.shape) == 1:
|
|
||||||
emb = emb.unsqueeze(0)
|
|
||||||
|
|
||||||
self.word_embeddings[name] = emb.detach().to(device)
|
|
||||||
self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}'
|
|
||||||
|
|
||||||
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
|
|
||||||
|
|
||||||
first_id = ids[0]
|
|
||||||
if first_id not in self.ids_lookup:
|
|
||||||
self.ids_lookup[first_id] = []
|
|
||||||
self.ids_lookup[first_id].append((ids, name))
|
|
||||||
|
|
||||||
for fn in os.listdir(dirname):
|
|
||||||
try:
|
|
||||||
fullfn = os.path.join(dirname, fn)
|
|
||||||
|
|
||||||
if os.stat(fullfn).st_size == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
process_file(fullfn, fn)
|
|
||||||
except Exception:
|
|
||||||
print(f"Error loading emedding {fn}:", file=sys.stderr)
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
continue
|
|
||||||
|
|
||||||
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
|
|
||||||
|
|
||||||
def hijack(self, m):
|
def hijack(self, m):
|
||||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||||
|
@ -253,12 +50,7 @@ class StableDiffusionModelHijack:
|
||||||
|
|
||||||
self.clip = m.cond_stage_model
|
self.clip = m.cond_stage_model
|
||||||
|
|
||||||
if cmd_opts.opt_split_attention_v1:
|
apply_optimizations()
|
||||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
|
|
||||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
|
||||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
|
|
||||||
ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
|
|
||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
|
|
||||||
|
|
||||||
def flatten(el):
|
def flatten(el):
|
||||||
flattened = [flatten(children) for children in el.children()]
|
flattened = [flatten(children) for children in el.children()]
|
||||||
|
@ -296,7 +88,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
def __init__(self, wrapped, hijack):
|
def __init__(self, wrapped, hijack):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.wrapped = wrapped
|
self.wrapped = wrapped
|
||||||
self.hijack = hijack
|
self.hijack: StableDiffusionModelHijack = hijack
|
||||||
self.tokenizer = wrapped.tokenizer
|
self.tokenizer = wrapped.tokenizer
|
||||||
self.max_length = wrapped.max_length
|
self.max_length = wrapped.max_length
|
||||||
self.token_mults = {}
|
self.token_mults = {}
|
||||||
|
@ -317,7 +109,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
if mult != 1.0:
|
if mult != 1.0:
|
||||||
self.token_mults[ident] = mult
|
self.token_mults[ident] = mult
|
||||||
|
|
||||||
|
|
||||||
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
||||||
id_start = self.wrapped.tokenizer.bos_token_id
|
id_start = self.wrapped.tokenizer.bos_token_id
|
||||||
id_end = self.wrapped.tokenizer.eos_token_id
|
id_end = self.wrapped.tokenizer.eos_token_id
|
||||||
|
@ -339,28 +130,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
while i < len(tokens):
|
while i < len(tokens):
|
||||||
token = tokens[i]
|
token = tokens[i]
|
||||||
|
|
||||||
possible_matches = self.hijack.ids_lookup.get(token, None)
|
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||||
|
|
||||||
if possible_matches is None:
|
if embedding is None:
|
||||||
remade_tokens.append(token)
|
remade_tokens.append(token)
|
||||||
multipliers.append(weight)
|
multipliers.append(weight)
|
||||||
|
i += 1
|
||||||
else:
|
else:
|
||||||
found = False
|
emb_len = int(embedding.vec.shape[0])
|
||||||
for ids, word in possible_matches:
|
fixes.append((len(remade_tokens), embedding))
|
||||||
if tokens[i:i + len(ids)] == ids:
|
remade_tokens += [0] * emb_len
|
||||||
emb_len = int(self.hijack.word_embeddings[word].shape[0])
|
multipliers += [weight] * emb_len
|
||||||
fixes.append((len(remade_tokens), word))
|
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||||
remade_tokens += [0] * emb_len
|
i += embedding_length_in_tokens
|
||||||
multipliers += [weight] * emb_len
|
|
||||||
i += len(ids) - 1
|
|
||||||
found = True
|
|
||||||
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
|
|
||||||
break
|
|
||||||
|
|
||||||
if not found:
|
|
||||||
remade_tokens.append(token)
|
|
||||||
multipliers.append(weight)
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
if len(remade_tokens) > maxlen - 2:
|
if len(remade_tokens) > maxlen - 2:
|
||||||
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||||
|
@ -431,32 +213,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
while i < len(tokens):
|
while i < len(tokens):
|
||||||
token = tokens[i]
|
token = tokens[i]
|
||||||
|
|
||||||
possible_matches = self.hijack.ids_lookup.get(token, None)
|
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||||
|
|
||||||
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
|
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
|
||||||
if mult_change is not None:
|
if mult_change is not None:
|
||||||
mult *= mult_change
|
mult *= mult_change
|
||||||
elif possible_matches is None:
|
i += 1
|
||||||
|
elif embedding is None:
|
||||||
remade_tokens.append(token)
|
remade_tokens.append(token)
|
||||||
multipliers.append(mult)
|
multipliers.append(mult)
|
||||||
|
i += 1
|
||||||
else:
|
else:
|
||||||
found = False
|
emb_len = int(embedding.vec.shape[0])
|
||||||
for ids, word in possible_matches:
|
fixes.append((len(remade_tokens), embedding))
|
||||||
if tokens[i:i+len(ids)] == ids:
|
remade_tokens += [0] * emb_len
|
||||||
emb_len = int(self.hijack.word_embeddings[word].shape[0])
|
multipliers += [mult] * emb_len
|
||||||
fixes.append((len(remade_tokens), word))
|
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||||
remade_tokens += [0] * emb_len
|
i += embedding_length_in_tokens
|
||||||
multipliers += [mult] * emb_len
|
|
||||||
i += len(ids) - 1
|
|
||||||
found = True
|
|
||||||
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
|
|
||||||
break
|
|
||||||
|
|
||||||
if not found:
|
|
||||||
remade_tokens.append(token)
|
|
||||||
multipliers.append(mult)
|
|
||||||
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
if len(remade_tokens) > maxlen - 2:
|
if len(remade_tokens) > maxlen - 2:
|
||||||
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||||
|
@ -464,6 +237,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||||
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||||
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||||
|
|
||||||
token_count = len(remade_tokens)
|
token_count = len(remade_tokens)
|
||||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||||
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
||||||
|
@ -484,7 +258,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
else:
|
else:
|
||||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
||||||
|
|
||||||
|
|
||||||
self.hijack.fixes = hijack_fixes
|
self.hijack.fixes = hijack_fixes
|
||||||
self.hijack.comments = hijack_comments
|
self.hijack.comments = hijack_comments
|
||||||
|
|
||||||
|
@ -517,14 +290,19 @@ class EmbeddingsWithFixes(torch.nn.Module):
|
||||||
|
|
||||||
inputs_embeds = self.wrapped(input_ids)
|
inputs_embeds = self.wrapped(input_ids)
|
||||||
|
|
||||||
if batch_fixes is not None:
|
if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
|
||||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
return inputs_embeds
|
||||||
for offset, word in fixes:
|
|
||||||
emb = self.embeddings.word_embeddings[word]
|
|
||||||
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
|
|
||||||
tensor[offset+1:offset+1+emb_len] = self.embeddings.word_embeddings[word][0:emb_len]
|
|
||||||
|
|
||||||
return inputs_embeds
|
vecs = []
|
||||||
|
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||||
|
for offset, embedding in fixes:
|
||||||
|
emb = embedding.vec
|
||||||
|
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
|
||||||
|
tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
|
||||||
|
|
||||||
|
vecs.append(tensor)
|
||||||
|
|
||||||
|
return torch.stack(vecs)
|
||||||
|
|
||||||
|
|
||||||
def add_circular_option_to_conv_2d():
|
def add_circular_option_to_conv_2d():
|
||||||
|
|
164
modules/sd_hijack_optimizations.py
Normal file
164
modules/sd_hijack_optimizations.py
Normal file
|
@ -0,0 +1,164 @@
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
from torch import einsum
|
||||||
|
|
||||||
|
from ldm.util import default
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
||||||
|
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
||||||
|
h = self.heads
|
||||||
|
|
||||||
|
q = self.to_q(x)
|
||||||
|
context = default(context, x)
|
||||||
|
k = self.to_k(context)
|
||||||
|
v = self.to_v(context)
|
||||||
|
del context, x
|
||||||
|
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||||
|
|
||||||
|
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
||||||
|
for i in range(0, q.shape[0], 2):
|
||||||
|
end = i + 2
|
||||||
|
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||||
|
s1 *= self.scale
|
||||||
|
|
||||||
|
s2 = s1.softmax(dim=-1)
|
||||||
|
del s1
|
||||||
|
|
||||||
|
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||||
|
del s2
|
||||||
|
|
||||||
|
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
del r1
|
||||||
|
|
||||||
|
return self.to_out(r2)
|
||||||
|
|
||||||
|
|
||||||
|
# taken from https://github.com/Doggettx/stable-diffusion
|
||||||
|
def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||||
|
h = self.heads
|
||||||
|
|
||||||
|
q_in = self.to_q(x)
|
||||||
|
context = default(context, x)
|
||||||
|
k_in = self.to_k(context) * self.scale
|
||||||
|
v_in = self.to_v(context)
|
||||||
|
del context, x
|
||||||
|
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||||
|
del q_in, k_in, v_in
|
||||||
|
|
||||||
|
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||||
|
|
||||||
|
stats = torch.cuda.memory_stats(q.device)
|
||||||
|
mem_active = stats['active_bytes.all.current']
|
||||||
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
|
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||||
|
mem_free_torch = mem_reserved - mem_active
|
||||||
|
mem_free_total = mem_free_cuda + mem_free_torch
|
||||||
|
|
||||||
|
gb = 1024 ** 3
|
||||||
|
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||||
|
modifier = 3 if q.element_size() == 2 else 2.5
|
||||||
|
mem_required = tensor_size * modifier
|
||||||
|
steps = 1
|
||||||
|
|
||||||
|
if mem_required > mem_free_total:
|
||||||
|
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||||
|
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
||||||
|
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
||||||
|
|
||||||
|
if steps > 64:
|
||||||
|
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||||
|
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||||
|
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
||||||
|
|
||||||
|
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||||
|
for i in range(0, q.shape[1], slice_size):
|
||||||
|
end = i + slice_size
|
||||||
|
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
||||||
|
|
||||||
|
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||||
|
del s1
|
||||||
|
|
||||||
|
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||||
|
del s2
|
||||||
|
|
||||||
|
del q, k, v
|
||||||
|
|
||||||
|
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
del r1
|
||||||
|
|
||||||
|
return self.to_out(r2)
|
||||||
|
|
||||||
|
def nonlinearity_hijack(x):
|
||||||
|
# swish
|
||||||
|
t = torch.sigmoid(x)
|
||||||
|
x *= t
|
||||||
|
del t
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def cross_attention_attnblock_forward(self, x):
|
||||||
|
h_ = x
|
||||||
|
h_ = self.norm(h_)
|
||||||
|
q1 = self.q(h_)
|
||||||
|
k1 = self.k(h_)
|
||||||
|
v = self.v(h_)
|
||||||
|
|
||||||
|
# compute attention
|
||||||
|
b, c, h, w = q1.shape
|
||||||
|
|
||||||
|
q2 = q1.reshape(b, c, h*w)
|
||||||
|
del q1
|
||||||
|
|
||||||
|
q = q2.permute(0, 2, 1) # b,hw,c
|
||||||
|
del q2
|
||||||
|
|
||||||
|
k = k1.reshape(b, c, h*w) # b,c,hw
|
||||||
|
del k1
|
||||||
|
|
||||||
|
h_ = torch.zeros_like(k, device=q.device)
|
||||||
|
|
||||||
|
stats = torch.cuda.memory_stats(q.device)
|
||||||
|
mem_active = stats['active_bytes.all.current']
|
||||||
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
|
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||||
|
mem_free_torch = mem_reserved - mem_active
|
||||||
|
mem_free_total = mem_free_cuda + mem_free_torch
|
||||||
|
|
||||||
|
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
||||||
|
mem_required = tensor_size * 2.5
|
||||||
|
steps = 1
|
||||||
|
|
||||||
|
if mem_required > mem_free_total:
|
||||||
|
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||||
|
|
||||||
|
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||||
|
for i in range(0, q.shape[1], slice_size):
|
||||||
|
end = i + slice_size
|
||||||
|
|
||||||
|
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||||
|
w2 = w1 * (int(c)**(-0.5))
|
||||||
|
del w1
|
||||||
|
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
|
||||||
|
del w2
|
||||||
|
|
||||||
|
# attend to values
|
||||||
|
v1 = v.reshape(b, c, h*w)
|
||||||
|
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||||
|
del w3
|
||||||
|
|
||||||
|
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||||
|
del v1, w4
|
||||||
|
|
||||||
|
h2 = h_.reshape(b, c, h, w)
|
||||||
|
del h_
|
||||||
|
|
||||||
|
h3 = self.proj_out(h2)
|
||||||
|
del h2
|
||||||
|
|
||||||
|
h3 += x
|
||||||
|
|
||||||
|
return h3
|
|
@ -8,14 +8,11 @@ from omegaconf import OmegaConf
|
||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
from modules import shared, modelloader
|
from modules import shared, modelloader, devices
|
||||||
from modules.paths import models_path
|
from modules.paths import models_path
|
||||||
|
|
||||||
model_dir = "Stable-diffusion"
|
model_dir = "Stable-diffusion"
|
||||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||||
model_name = "sd-v1-4.ckpt"
|
|
||||||
model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1"
|
|
||||||
user_dir = None
|
|
||||||
|
|
||||||
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
|
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
|
||||||
checkpoints_list = {}
|
checkpoints_list = {}
|
||||||
|
@ -30,12 +27,10 @@ except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def setup_model(dirname):
|
def setup_model():
|
||||||
global user_dir
|
|
||||||
user_dir = dirname
|
|
||||||
if not os.path.exists(model_path):
|
if not os.path.exists(model_path):
|
||||||
os.makedirs(model_path)
|
os.makedirs(model_path)
|
||||||
checkpoints_list.clear()
|
|
||||||
list_models()
|
list_models()
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,13 +40,13 @@ def checkpoint_tiles():
|
||||||
|
|
||||||
def list_models():
|
def list_models():
|
||||||
checkpoints_list.clear()
|
checkpoints_list.clear()
|
||||||
model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=user_dir, ext_filter=[".ckpt"], download_name=model_name)
|
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])
|
||||||
|
|
||||||
def modeltitle(path, shorthash):
|
def modeltitle(path, shorthash):
|
||||||
abspath = os.path.abspath(path)
|
abspath = os.path.abspath(path)
|
||||||
|
|
||||||
if user_dir is not None and abspath.startswith(user_dir):
|
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
|
||||||
name = abspath.replace(user_dir, '')
|
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
|
||||||
elif abspath.startswith(model_path):
|
elif abspath.startswith(model_path):
|
||||||
name = abspath.replace(model_path, '')
|
name = abspath.replace(model_path, '')
|
||||||
else:
|
else:
|
||||||
|
@ -69,7 +64,7 @@ def list_models():
|
||||||
h = model_hash(cmd_ckpt)
|
h = model_hash(cmd_ckpt)
|
||||||
title, short_model_name = modeltitle(cmd_ckpt, h)
|
title, short_model_name = modeltitle(cmd_ckpt, h)
|
||||||
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
|
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
|
||||||
shared.opts.sd_model_checkpoint = title
|
shared.opts.data['sd_model_checkpoint'] = title
|
||||||
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||||
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
||||||
for filename in model_list:
|
for filename in model_list:
|
||||||
|
@ -106,8 +101,11 @@ def select_checkpoint():
|
||||||
|
|
||||||
if len(checkpoints_list) == 0:
|
if len(checkpoints_list) == 0:
|
||||||
print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
|
print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
|
||||||
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
|
if shared.cmd_opts.ckpt is not None:
|
||||||
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
|
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
|
||||||
|
print(f" - directory {model_path}", file=sys.stderr)
|
||||||
|
if shared.cmd_opts.ckpt_dir is not None:
|
||||||
|
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
|
||||||
print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
|
print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
|
@ -134,6 +132,8 @@ def load_model_weights(model, checkpoint_file, sd_model_hash):
|
||||||
if not shared.cmd_opts.no_half:
|
if not shared.cmd_opts.no_half:
|
||||||
model.half()
|
model.half()
|
||||||
|
|
||||||
|
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
||||||
|
|
||||||
model.sd_model_hash = sd_model_hash
|
model.sd_model_hash = sd_model_hash
|
||||||
model.sd_model_checkpint = checkpoint_file
|
model.sd_model_checkpint = checkpoint_file
|
||||||
|
|
||||||
|
|
|
@ -77,7 +77,9 @@ def extended_tdqm(sequence, *args, desc=None, **kwargs):
|
||||||
state.sampling_steps = len(sequence)
|
state.sampling_steps = len(sequence)
|
||||||
state.sampling_step = 0
|
state.sampling_step = 0
|
||||||
|
|
||||||
for x in tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs):
|
seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
|
||||||
|
|
||||||
|
for x in seq:
|
||||||
if state.interrupted:
|
if state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -207,7 +209,9 @@ def extended_trange(sampler, count, *args, **kwargs):
|
||||||
state.sampling_steps = count
|
state.sampling_steps = count
|
||||||
state.sampling_step = 0
|
state.sampling_step = 0
|
||||||
|
|
||||||
for x in tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs):
|
seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
|
||||||
|
|
||||||
|
for x in seq:
|
||||||
if state.interrupted:
|
if state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
|
@ -40,6 +40,7 @@ parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory wi
|
||||||
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN'))
|
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN'))
|
||||||
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN'))
|
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN'))
|
||||||
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN'))
|
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN'))
|
||||||
|
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(model_path, 'ScuNET'))
|
||||||
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR'))
|
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR'))
|
||||||
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR'))
|
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR'))
|
||||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
|
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
|
||||||
|
@ -57,6 +58,9 @@ parser.add_argument("--opt-channelslast", action='store_true', help="change memo
|
||||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
|
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
|
||||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||||
|
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
||||||
|
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
||||||
|
|
||||||
|
|
||||||
cmd_opts = parser.parse_args()
|
cmd_opts = parser.parse_args()
|
||||||
device = get_optimal_device()
|
device = get_optimal_device()
|
||||||
|
@ -78,6 +82,7 @@ class State:
|
||||||
current_latent = None
|
current_latent = None
|
||||||
current_image = None
|
current_image = None
|
||||||
current_image_sampling_step = 0
|
current_image_sampling_step = 0
|
||||||
|
textinfo = None
|
||||||
|
|
||||||
def interrupt(self):
|
def interrupt(self):
|
||||||
self.interrupted = True
|
self.interrupted = True
|
||||||
|
@ -88,7 +93,7 @@ class State:
|
||||||
self.current_image_sampling_step = 0
|
self.current_image_sampling_step = 0
|
||||||
|
|
||||||
def get_job_timestamp(self):
|
def get_job_timestamp(self):
|
||||||
return datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
|
||||||
|
|
||||||
|
|
||||||
state = State()
|
state = State()
|
||||||
|
@ -165,9 +170,10 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
||||||
|
|
||||||
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
||||||
"save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
|
"save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
|
||||||
"grid_save_to_dirs": OptionInfo(False, "Save grids to subdirectory"),
|
"grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"),
|
||||||
|
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
||||||
"directories_filename_pattern": OptionInfo("", "Directory name pattern"),
|
"directories_filename_pattern": OptionInfo("", "Directory name pattern"),
|
||||||
"directories_max_prompt_words": OptionInfo(8, "Max prompt words", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}),
|
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||||
|
@ -318,14 +324,14 @@ class TotalTQDM:
|
||||||
)
|
)
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
if not opts.multiple_tqdm:
|
if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
|
||||||
return
|
return
|
||||||
if self._tqdm is None:
|
if self._tqdm is None:
|
||||||
self.reset()
|
self.reset()
|
||||||
self._tqdm.update()
|
self._tqdm.update()
|
||||||
|
|
||||||
def updateTotal(self, new_total):
|
def updateTotal(self, new_total):
|
||||||
if not opts.multiple_tqdm:
|
if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
|
||||||
return
|
return
|
||||||
if self._tqdm is None:
|
if self._tqdm is None:
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
|
@ -5,6 +5,7 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from modules import modelloader
|
from modules import modelloader
|
||||||
from modules.paths import models_path
|
from modules.paths import models_path
|
||||||
|
@ -122,18 +123,20 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
|
||||||
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
|
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
|
||||||
W = torch.zeros_like(E, dtype=torch.half, device=device)
|
W = torch.zeros_like(E, dtype=torch.half, device=device)
|
||||||
|
|
||||||
for h_idx in h_idx_list:
|
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
|
||||||
for w_idx in w_idx_list:
|
for h_idx in h_idx_list:
|
||||||
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
for w_idx in w_idx_list:
|
||||||
out_patch = model(in_patch)
|
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
||||||
out_patch_mask = torch.ones_like(out_patch)
|
out_patch = model(in_patch)
|
||||||
|
out_patch_mask = torch.ones_like(out_patch)
|
||||||
|
|
||||||
E[
|
E[
|
||||||
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
||||||
].add_(out_patch)
|
].add_(out_patch)
|
||||||
W[
|
W[
|
||||||
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
||||||
].add_(out_patch_mask)
|
].add_(out_patch_mask)
|
||||||
|
pbar.update(1)
|
||||||
output = E.div_(W)
|
output = E.div_(W)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
78
modules/textual_inversion/dataset.py
Normal file
78
modules/textual_inversion/dataset.py
Normal file
|
@ -0,0 +1,78 @@
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import PIL
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
import random
|
||||||
|
import tqdm
|
||||||
|
from modules import devices
|
||||||
|
|
||||||
|
|
||||||
|
class PersonalizedBase(Dataset):
|
||||||
|
def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
|
||||||
|
|
||||||
|
self.placeholder_token = placeholder_token
|
||||||
|
|
||||||
|
self.size = size
|
||||||
|
self.width = width
|
||||||
|
self.height = height
|
||||||
|
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||||
|
|
||||||
|
self.dataset = []
|
||||||
|
|
||||||
|
with open(template_file, "r") as file:
|
||||||
|
lines = [x.strip() for x in file.readlines()]
|
||||||
|
|
||||||
|
self.lines = lines
|
||||||
|
|
||||||
|
assert data_root, 'dataset directory not specified'
|
||||||
|
|
||||||
|
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||||
|
print("Preparing dataset...")
|
||||||
|
for path in tqdm.tqdm(self.image_paths):
|
||||||
|
image = Image.open(path)
|
||||||
|
image = image.convert('RGB')
|
||||||
|
image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
|
||||||
|
|
||||||
|
filename = os.path.basename(path)
|
||||||
|
filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-')
|
||||||
|
filename_tokens = [token for token in filename_tokens if token.isalpha()]
|
||||||
|
|
||||||
|
npimage = np.array(image).astype(np.uint8)
|
||||||
|
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
||||||
|
|
||||||
|
torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
|
||||||
|
torchdata = torch.moveaxis(torchdata, 2, 0)
|
||||||
|
|
||||||
|
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
|
||||||
|
init_latent = init_latent.to(devices.cpu)
|
||||||
|
|
||||||
|
self.dataset.append((init_latent, filename_tokens))
|
||||||
|
|
||||||
|
self.length = len(self.dataset) * repeats
|
||||||
|
|
||||||
|
self.initial_indexes = np.arange(self.length) % len(self.dataset)
|
||||||
|
self.indexes = None
|
||||||
|
self.shuffle()
|
||||||
|
|
||||||
|
def shuffle(self):
|
||||||
|
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.length
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
if i % len(self.dataset) == 0:
|
||||||
|
self.shuffle()
|
||||||
|
|
||||||
|
index = self.indexes[i % len(self.indexes)]
|
||||||
|
x, filename_tokens = self.dataset[index]
|
||||||
|
|
||||||
|
text = random.choice(self.lines)
|
||||||
|
text = text.replace("[name]", self.placeholder_token)
|
||||||
|
text = text.replace("[filewords]", ' '.join(filename_tokens))
|
||||||
|
|
||||||
|
return x, text
|
75
modules/textual_inversion/preprocess.py
Normal file
75
modules/textual_inversion/preprocess.py
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
import os
|
||||||
|
from PIL import Image, ImageOps
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
from modules import shared, images
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess(process_src, process_dst, process_flip, process_split, process_caption):
|
||||||
|
size = 512
|
||||||
|
src = os.path.abspath(process_src)
|
||||||
|
dst = os.path.abspath(process_dst)
|
||||||
|
|
||||||
|
assert src != dst, 'same directory specified as source and desitnation'
|
||||||
|
|
||||||
|
os.makedirs(dst, exist_ok=True)
|
||||||
|
|
||||||
|
files = os.listdir(src)
|
||||||
|
|
||||||
|
shared.state.textinfo = "Preprocessing..."
|
||||||
|
shared.state.job_count = len(files)
|
||||||
|
|
||||||
|
if process_caption:
|
||||||
|
shared.interrogator.load()
|
||||||
|
|
||||||
|
def save_pic_with_caption(image, index):
|
||||||
|
if process_caption:
|
||||||
|
caption = "-" + shared.interrogator.generate_caption(image)
|
||||||
|
else:
|
||||||
|
caption = ""
|
||||||
|
|
||||||
|
image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
|
||||||
|
subindex[0] += 1
|
||||||
|
|
||||||
|
def save_pic(image, index):
|
||||||
|
save_pic_with_caption(image, index)
|
||||||
|
|
||||||
|
if process_flip:
|
||||||
|
save_pic_with_caption(ImageOps.mirror(image), index)
|
||||||
|
|
||||||
|
for index, imagefile in enumerate(tqdm.tqdm(files)):
|
||||||
|
subindex = [0]
|
||||||
|
filename = os.path.join(src, imagefile)
|
||||||
|
img = Image.open(filename).convert("RGB")
|
||||||
|
|
||||||
|
if shared.state.interrupted:
|
||||||
|
break
|
||||||
|
|
||||||
|
ratio = img.height / img.width
|
||||||
|
is_tall = ratio > 1.35
|
||||||
|
is_wide = ratio < 1 / 1.35
|
||||||
|
|
||||||
|
if process_split and is_tall:
|
||||||
|
img = img.resize((size, size * img.height // img.width))
|
||||||
|
|
||||||
|
top = img.crop((0, 0, size, size))
|
||||||
|
save_pic(top, index)
|
||||||
|
|
||||||
|
bot = img.crop((0, img.height - size, size, img.height))
|
||||||
|
save_pic(bot, index)
|
||||||
|
elif process_split and is_wide:
|
||||||
|
img = img.resize((size * img.width // img.height, size))
|
||||||
|
|
||||||
|
left = img.crop((0, 0, size, size))
|
||||||
|
save_pic(left, index)
|
||||||
|
|
||||||
|
right = img.crop((img.width - size, 0, img.width, size))
|
||||||
|
save_pic(right, index)
|
||||||
|
else:
|
||||||
|
img = images.resize_image(1, img, size, size)
|
||||||
|
save_pic(img, index)
|
||||||
|
|
||||||
|
shared.state.nextjob()
|
||||||
|
|
||||||
|
if process_caption:
|
||||||
|
shared.interrogator.send_blip_to_ram()
|
271
modules/textual_inversion/textual_inversion.py
Normal file
271
modules/textual_inversion/textual_inversion.py
Normal file
|
@ -0,0 +1,271 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
import html
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
|
||||||
|
from modules import shared, devices, sd_hijack, processing, sd_models
|
||||||
|
import modules.textual_inversion.dataset
|
||||||
|
|
||||||
|
|
||||||
|
class Embedding:
|
||||||
|
def __init__(self, vec, name, step=None):
|
||||||
|
self.vec = vec
|
||||||
|
self.name = name
|
||||||
|
self.step = step
|
||||||
|
self.cached_checksum = None
|
||||||
|
self.sd_checkpoint = None
|
||||||
|
self.sd_checkpoint_name = None
|
||||||
|
|
||||||
|
def save(self, filename):
|
||||||
|
embedding_data = {
|
||||||
|
"string_to_token": {"*": 265},
|
||||||
|
"string_to_param": {"*": self.vec},
|
||||||
|
"name": self.name,
|
||||||
|
"step": self.step,
|
||||||
|
"sd_checkpoint": self.sd_checkpoint,
|
||||||
|
"sd_checkpoint_name": self.sd_checkpoint_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
torch.save(embedding_data, filename)
|
||||||
|
|
||||||
|
def checksum(self):
|
||||||
|
if self.cached_checksum is not None:
|
||||||
|
return self.cached_checksum
|
||||||
|
|
||||||
|
def const_hash(a):
|
||||||
|
r = 0
|
||||||
|
for v in a:
|
||||||
|
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
|
||||||
|
return r
|
||||||
|
|
||||||
|
self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
|
||||||
|
return self.cached_checksum
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingDatabase:
|
||||||
|
def __init__(self, embeddings_dir):
|
||||||
|
self.ids_lookup = {}
|
||||||
|
self.word_embeddings = {}
|
||||||
|
self.dir_mtime = None
|
||||||
|
self.embeddings_dir = embeddings_dir
|
||||||
|
|
||||||
|
def register_embedding(self, embedding, model):
|
||||||
|
|
||||||
|
self.word_embeddings[embedding.name] = embedding
|
||||||
|
|
||||||
|
ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
|
||||||
|
|
||||||
|
first_id = ids[0]
|
||||||
|
if first_id not in self.ids_lookup:
|
||||||
|
self.ids_lookup[first_id] = []
|
||||||
|
|
||||||
|
self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
|
||||||
|
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
def load_textual_inversion_embeddings(self):
|
||||||
|
mt = os.path.getmtime(self.embeddings_dir)
|
||||||
|
if self.dir_mtime is not None and mt <= self.dir_mtime:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.dir_mtime = mt
|
||||||
|
self.ids_lookup.clear()
|
||||||
|
self.word_embeddings.clear()
|
||||||
|
|
||||||
|
def process_file(path, filename):
|
||||||
|
name = os.path.splitext(filename)[0]
|
||||||
|
|
||||||
|
data = torch.load(path, map_location="cpu")
|
||||||
|
|
||||||
|
# textual inversion embeddings
|
||||||
|
if 'string_to_param' in data:
|
||||||
|
param_dict = data['string_to_param']
|
||||||
|
if hasattr(param_dict, '_parameters'):
|
||||||
|
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||||
|
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||||
|
emb = next(iter(param_dict.items()))[1]
|
||||||
|
# diffuser concepts
|
||||||
|
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
||||||
|
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
||||||
|
|
||||||
|
emb = next(iter(data.values()))
|
||||||
|
if len(emb.shape) == 1:
|
||||||
|
emb = emb.unsqueeze(0)
|
||||||
|
else:
|
||||||
|
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
||||||
|
|
||||||
|
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||||
|
embedding = Embedding(vec, name)
|
||||||
|
embedding.step = data.get('step', None)
|
||||||
|
embedding.sd_checkpoint = data.get('hash', None)
|
||||||
|
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
||||||
|
self.register_embedding(embedding, shared.sd_model)
|
||||||
|
|
||||||
|
for fn in os.listdir(self.embeddings_dir):
|
||||||
|
try:
|
||||||
|
fullfn = os.path.join(self.embeddings_dir, fn)
|
||||||
|
|
||||||
|
if os.stat(fullfn).st_size == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
process_file(fullfn, fn)
|
||||||
|
except Exception:
|
||||||
|
print(f"Error loading emedding {fn}:", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
|
||||||
|
|
||||||
|
def find_embedding_at_position(self, tokens, offset):
|
||||||
|
token = tokens[offset]
|
||||||
|
possible_matches = self.ids_lookup.get(token, None)
|
||||||
|
|
||||||
|
if possible_matches is None:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
for ids, embedding in possible_matches:
|
||||||
|
if tokens[offset:offset + len(ids)] == ids:
|
||||||
|
return embedding, len(ids)
|
||||||
|
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
def create_embedding(name, num_vectors_per_token, init_text='*'):
|
||||||
|
cond_model = shared.sd_model.cond_stage_model
|
||||||
|
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
|
||||||
|
|
||||||
|
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
||||||
|
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
|
||||||
|
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
|
||||||
|
|
||||||
|
for i in range(num_vectors_per_token):
|
||||||
|
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
||||||
|
|
||||||
|
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
|
||||||
|
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||||
|
|
||||||
|
embedding = Embedding(vec, name)
|
||||||
|
embedding.step = 0
|
||||||
|
embedding.save(fn)
|
||||||
|
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
|
||||||
|
assert embedding_name, 'embedding not selected'
|
||||||
|
|
||||||
|
shared.state.textinfo = "Initializing textual inversion training..."
|
||||||
|
shared.state.job_count = steps
|
||||||
|
|
||||||
|
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||||
|
|
||||||
|
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name)
|
||||||
|
|
||||||
|
if save_embedding_every > 0:
|
||||||
|
embedding_dir = os.path.join(log_directory, "embeddings")
|
||||||
|
os.makedirs(embedding_dir, exist_ok=True)
|
||||||
|
else:
|
||||||
|
embedding_dir = None
|
||||||
|
|
||||||
|
if create_image_every > 0:
|
||||||
|
images_dir = os.path.join(log_directory, "images")
|
||||||
|
os.makedirs(images_dir, exist_ok=True)
|
||||||
|
else:
|
||||||
|
images_dir = None
|
||||||
|
|
||||||
|
cond_model = shared.sd_model.cond_stage_model
|
||||||
|
|
||||||
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
|
with torch.autocast("cuda"):
|
||||||
|
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
|
||||||
|
|
||||||
|
hijack = sd_hijack.model_hijack
|
||||||
|
|
||||||
|
embedding = hijack.embedding_db.word_embeddings[embedding_name]
|
||||||
|
embedding.vec.requires_grad = True
|
||||||
|
|
||||||
|
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
|
||||||
|
|
||||||
|
losses = torch.zeros((32,))
|
||||||
|
|
||||||
|
last_saved_file = "<none>"
|
||||||
|
last_saved_image = "<none>"
|
||||||
|
|
||||||
|
ititial_step = embedding.step or 0
|
||||||
|
if ititial_step > steps:
|
||||||
|
return embedding, filename
|
||||||
|
|
||||||
|
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
||||||
|
for i, (x, text) in pbar:
|
||||||
|
embedding.step = i + ititial_step
|
||||||
|
|
||||||
|
if embedding.step > steps:
|
||||||
|
break
|
||||||
|
|
||||||
|
if shared.state.interrupted:
|
||||||
|
break
|
||||||
|
|
||||||
|
with torch.autocast("cuda"):
|
||||||
|
c = cond_model([text])
|
||||||
|
|
||||||
|
x = x.to(devices.device)
|
||||||
|
loss = shared.sd_model(x.unsqueeze(0), c)[0]
|
||||||
|
del x
|
||||||
|
|
||||||
|
losses[embedding.step % losses.shape[0]] = loss.item()
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
pbar.set_description(f"loss: {losses.mean():.7f}")
|
||||||
|
|
||||||
|
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
|
||||||
|
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
||||||
|
embedding.save(last_saved_file)
|
||||||
|
|
||||||
|
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
|
||||||
|
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
|
||||||
|
|
||||||
|
p = processing.StableDiffusionProcessingTxt2Img(
|
||||||
|
sd_model=shared.sd_model,
|
||||||
|
prompt=text,
|
||||||
|
steps=20,
|
||||||
|
do_not_save_grid=True,
|
||||||
|
do_not_save_samples=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
processed = processing.process_images(p)
|
||||||
|
image = processed.images[0]
|
||||||
|
|
||||||
|
shared.state.current_image = image
|
||||||
|
image.save(last_saved_image)
|
||||||
|
|
||||||
|
last_saved_image += f", prompt: {text}"
|
||||||
|
|
||||||
|
shared.state.job_no = embedding.step
|
||||||
|
|
||||||
|
shared.state.textinfo = f"""
|
||||||
|
<p>
|
||||||
|
Loss: {losses.mean():.7f}<br/>
|
||||||
|
Step: {embedding.step}<br/>
|
||||||
|
Last prompt: {html.escape(text)}<br/>
|
||||||
|
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||||
|
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
|
</p>
|
||||||
|
"""
|
||||||
|
|
||||||
|
checkpoint = sd_models.select_checkpoint()
|
||||||
|
|
||||||
|
embedding.sd_checkpoint = checkpoint.hash
|
||||||
|
embedding.sd_checkpoint_name = checkpoint.model_name
|
||||||
|
embedding.cached_checksum = None
|
||||||
|
embedding.save(filename)
|
||||||
|
|
||||||
|
return embedding, filename
|
||||||
|
|
40
modules/textual_inversion/ui.py
Normal file
40
modules/textual_inversion/ui.py
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
import html
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
import modules.textual_inversion.textual_inversion
|
||||||
|
import modules.textual_inversion.preprocess
|
||||||
|
from modules import sd_hijack, shared
|
||||||
|
|
||||||
|
|
||||||
|
def create_embedding(name, initialization_text, nvpt):
|
||||||
|
filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
|
||||||
|
|
||||||
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||||
|
|
||||||
|
return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess(*args):
|
||||||
|
modules.textual_inversion.preprocess.preprocess(*args)
|
||||||
|
|
||||||
|
return "Preprocessing finished.", ""
|
||||||
|
|
||||||
|
|
||||||
|
def train_embedding(*args):
|
||||||
|
|
||||||
|
try:
|
||||||
|
sd_hijack.undo_optimizations()
|
||||||
|
|
||||||
|
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
|
||||||
|
|
||||||
|
res = f"""
|
||||||
|
Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
|
||||||
|
Embedding saved to {html.escape(filename)}
|
||||||
|
"""
|
||||||
|
return res, ""
|
||||||
|
except Exception:
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
sd_hijack.apply_optimizations()
|
||||||
|
|
|
@ -34,7 +34,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
|
||||||
denoising_strength=denoising_strength if enable_hr else None,
|
denoising_strength=denoising_strength if enable_hr else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
|
if cmd_opts.enable_console_prompts:
|
||||||
|
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
|
||||||
|
|
||||||
processed = modules.scripts.scripts_txt2img.run(p, *args)
|
processed = modules.scripts.scripts_txt2img.run(p, *args)
|
||||||
|
|
||||||
if processed is None:
|
if processed is None:
|
||||||
|
|
281
modules/ui.py
281
modules/ui.py
|
@ -11,6 +11,7 @@ import time
|
||||||
import traceback
|
import traceback
|
||||||
import platform
|
import platform
|
||||||
import subprocess as sp
|
import subprocess as sp
|
||||||
|
from functools import reduce
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -21,6 +22,7 @@ import gradio as gr
|
||||||
import gradio.utils
|
import gradio.utils
|
||||||
import gradio.routes
|
import gradio.routes
|
||||||
|
|
||||||
|
from modules import sd_hijack
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
from modules.shared import opts, cmd_opts
|
from modules.shared import opts, cmd_opts
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
@ -32,6 +34,9 @@ import modules.gfpgan_model
|
||||||
import modules.codeformer_model
|
import modules.codeformer_model
|
||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.generation_parameters_copypaste
|
import modules.generation_parameters_copypaste
|
||||||
|
from modules.prompt_parser import get_learned_conditioning_prompt_schedules
|
||||||
|
from modules.images import apply_filename_pattern, get_next_sequence_number
|
||||||
|
import modules.textual_inversion.ui
|
||||||
|
|
||||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
|
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
|
||||||
mimetypes.init()
|
mimetypes.init()
|
||||||
|
@ -95,13 +100,30 @@ def send_gradio_gallery_to_image(x):
|
||||||
|
|
||||||
def save_files(js_data, images, index):
|
def save_files(js_data, images, index):
|
||||||
import csv
|
import csv
|
||||||
|
|
||||||
os.makedirs(opts.outdir_save, exist_ok=True)
|
|
||||||
|
|
||||||
filenames = []
|
filenames = []
|
||||||
|
|
||||||
|
#quick dictionary to class object conversion. Its neccesary due apply_filename_pattern requiring it
|
||||||
|
class MyObject:
|
||||||
|
def __init__(self, d=None):
|
||||||
|
if d is not None:
|
||||||
|
for key, value in d.items():
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
data = json.loads(js_data)
|
data = json.loads(js_data)
|
||||||
|
|
||||||
|
p = MyObject(data)
|
||||||
|
path = opts.outdir_save
|
||||||
|
save_to_dirs = opts.use_save_to_dirs_for_ui
|
||||||
|
|
||||||
|
if save_to_dirs:
|
||||||
|
dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, p.seed, p.prompt)
|
||||||
|
path = os.path.join(opts.outdir_save, dirname)
|
||||||
|
|
||||||
|
os.makedirs(path, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
|
if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
|
||||||
|
|
||||||
images = [images[index]]
|
images = [images[index]]
|
||||||
infotexts = [data["infotexts"][index]]
|
infotexts = [data["infotexts"][index]]
|
||||||
else:
|
else:
|
||||||
|
@ -113,11 +135,20 @@ def save_files(js_data, images, index):
|
||||||
if at_start:
|
if at_start:
|
||||||
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
|
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
|
||||||
|
|
||||||
filename_base = str(int(time.time() * 1000))
|
file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
|
||||||
|
if file_decoration != "":
|
||||||
|
file_decoration = "-" + file_decoration.lower()
|
||||||
|
file_decoration = apply_filename_pattern(file_decoration, p, p.seed, p.prompt)
|
||||||
|
truncated = (file_decoration[:240] + '..') if len(file_decoration) > 240 else file_decoration
|
||||||
|
filename_base = truncated
|
||||||
extension = opts.samples_format.lower()
|
extension = opts.samples_format.lower()
|
||||||
|
|
||||||
|
basecount = get_next_sequence_number(path, "")
|
||||||
for i, filedata in enumerate(images):
|
for i, filedata in enumerate(images):
|
||||||
filename = filename_base + ("" if len(images) == 1 else "-" + str(i + 1)) + f".{extension}"
|
file_number = f"{basecount+i:05}"
|
||||||
filepath = os.path.join(opts.outdir_save, filename)
|
filename = file_number + filename_base + f".{extension}"
|
||||||
|
filepath = os.path.join(path, filename)
|
||||||
|
|
||||||
|
|
||||||
if filedata.startswith("data:image/png;base64,"):
|
if filedata.startswith("data:image/png;base64,"):
|
||||||
filedata = filedata[len("data:image/png;base64,"):]
|
filedata = filedata[len("data:image/png;base64,"):]
|
||||||
|
@ -142,8 +173,8 @@ def save_files(js_data, images, index):
|
||||||
return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
|
return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
|
||||||
|
|
||||||
|
|
||||||
def wrap_gradio_call(func):
|
def wrap_gradio_call(func, extra_outputs=None):
|
||||||
def f(*args, **kwargs):
|
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
||||||
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
|
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
|
||||||
if run_memmon:
|
if run_memmon:
|
||||||
shared.mem_mon.monitor()
|
shared.mem_mon.monitor()
|
||||||
|
@ -159,7 +190,10 @@ def wrap_gradio_call(func):
|
||||||
shared.state.job = ""
|
shared.state.job = ""
|
||||||
shared.state.job_count = 0
|
shared.state.job_count = 0
|
||||||
|
|
||||||
res = [None, '', f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
|
if extra_outputs_array is None:
|
||||||
|
extra_outputs_array = [None, '']
|
||||||
|
|
||||||
|
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
|
||||||
|
|
||||||
elapsed = time.perf_counter() - t
|
elapsed = time.perf_counter() - t
|
||||||
|
|
||||||
|
@ -179,6 +213,7 @@ def wrap_gradio_call(func):
|
||||||
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>"
|
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>"
|
||||||
|
|
||||||
shared.state.interrupted = False
|
shared.state.interrupted = False
|
||||||
|
shared.state.job_count = 0
|
||||||
|
|
||||||
return tuple(res)
|
return tuple(res)
|
||||||
|
|
||||||
|
@ -187,7 +222,7 @@ def wrap_gradio_call(func):
|
||||||
|
|
||||||
def check_progress_call(id_part):
|
def check_progress_call(id_part):
|
||||||
if shared.state.job_count == 0:
|
if shared.state.job_count == 0:
|
||||||
return "", gr_show(False), gr_show(False)
|
return "", gr_show(False), gr_show(False), gr_show(False)
|
||||||
|
|
||||||
progress = 0
|
progress = 0
|
||||||
|
|
||||||
|
@ -219,13 +254,19 @@ def check_progress_call(id_part):
|
||||||
else:
|
else:
|
||||||
preview_visibility = gr_show(True)
|
preview_visibility = gr_show(True)
|
||||||
|
|
||||||
return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image
|
if shared.state.textinfo is not None:
|
||||||
|
textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
|
||||||
|
else:
|
||||||
|
textinfo_result = gr_show(False)
|
||||||
|
|
||||||
|
return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
|
||||||
|
|
||||||
|
|
||||||
def check_progress_call_initial(id_part):
|
def check_progress_call_initial(id_part):
|
||||||
shared.state.job_count = -1
|
shared.state.job_count = -1
|
||||||
shared.state.current_latent = None
|
shared.state.current_latent = None
|
||||||
shared.state.current_image = None
|
shared.state.current_image = None
|
||||||
|
shared.state.textinfo = None
|
||||||
|
|
||||||
return check_progress_call(id_part)
|
return check_progress_call(id_part)
|
||||||
|
|
||||||
|
@ -345,8 +386,11 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
|
||||||
outputs=[seed, dummy_component]
|
outputs=[seed, dummy_component]
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_token_counter(text):
|
def update_token_counter(text, steps):
|
||||||
tokens, token_count, max_length = model_hijack.tokenize(text)
|
prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps)
|
||||||
|
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
|
||||||
|
prompts = [prompt_text for step,prompt_text in flat_prompts]
|
||||||
|
tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
|
||||||
style_class = ' class="red"' if (token_count > max_length) else ""
|
style_class = ' class="red"' if (token_count > max_length) else ""
|
||||||
return f"<span {style_class}>{token_count}/{max_length}</span>"
|
return f"<span {style_class}>{token_count}/{max_length}</span>"
|
||||||
|
|
||||||
|
@ -364,8 +408,7 @@ def create_toprow(is_img2img):
|
||||||
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
||||||
paste = gr.Button(value=paste_symbol, elem_id="paste")
|
paste = gr.Button(value=paste_symbol, elem_id="paste")
|
||||||
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
|
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
|
||||||
hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
||||||
hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter])
|
|
||||||
|
|
||||||
with gr.Column(scale=10, elem_id="style_pos_col"):
|
with gr.Column(scale=10, elem_id="style_pos_col"):
|
||||||
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
|
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
|
||||||
|
@ -396,16 +439,19 @@ def create_toprow(is_img2img):
|
||||||
prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
|
prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
|
||||||
save_style = gr.Button('Create style', elem_id="style_create")
|
save_style = gr.Button('Create style', elem_id="style_create")
|
||||||
|
|
||||||
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste
|
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste, token_counter, token_button
|
||||||
|
|
||||||
|
|
||||||
def setup_progressbar(progressbar, preview, id_part):
|
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
|
||||||
|
if textinfo is None:
|
||||||
|
textinfo = gr.HTML(visible=False)
|
||||||
|
|
||||||
check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
|
check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
|
||||||
check_progress.click(
|
check_progress.click(
|
||||||
fn=lambda: check_progress_call(id_part),
|
fn=lambda: check_progress_call(id_part),
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[progressbar, preview, preview],
|
outputs=[progressbar, preview, preview, textinfo],
|
||||||
)
|
)
|
||||||
|
|
||||||
check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
|
check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
|
||||||
|
@ -413,13 +459,16 @@ def setup_progressbar(progressbar, preview, id_part):
|
||||||
fn=lambda: check_progress_call_initial(id_part),
|
fn=lambda: check_progress_call_initial(id_part),
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[progressbar, preview, preview],
|
outputs=[progressbar, preview, preview, textinfo],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
def create_ui(wrap_gradio_gpu_call):
|
||||||
|
import modules.img2img
|
||||||
|
import modules.txt2img
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||||
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False)
|
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False)
|
||||||
dummy_component = gr.Label(visible=False)
|
dummy_component = gr.Label(visible=False)
|
||||||
|
|
||||||
with gr.Row(elem_id='txt2img_progress_row'):
|
with gr.Row(elem_id='txt2img_progress_row'):
|
||||||
|
@ -483,7 +532,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
||||||
|
|
||||||
txt2img_args = dict(
|
txt2img_args = dict(
|
||||||
fn=txt2img,
|
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
|
||||||
_js="submit",
|
_js="submit",
|
||||||
inputs=[
|
inputs=[
|
||||||
txt2img_prompt,
|
txt2img_prompt,
|
||||||
|
@ -539,6 +588,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
|
|
||||||
roll.click(
|
roll.click(
|
||||||
fn=roll_artist,
|
fn=roll_artist,
|
||||||
|
_js="update_txt2img_tokens",
|
||||||
inputs=[
|
inputs=[
|
||||||
txt2img_prompt,
|
txt2img_prompt,
|
||||||
],
|
],
|
||||||
|
@ -567,9 +617,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
|
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
|
||||||
]
|
]
|
||||||
modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt)
|
modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt)
|
||||||
|
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||||
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste = create_toprow(is_img2img=True)
|
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True)
|
||||||
|
|
||||||
with gr.Row(elem_id='img2img_progress_row'):
|
with gr.Row(elem_id='img2img_progress_row'):
|
||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
|
@ -675,7 +726,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
)
|
)
|
||||||
|
|
||||||
img2img_args = dict(
|
img2img_args = dict(
|
||||||
fn=img2img,
|
fn=wrap_gradio_gpu_call(modules.img2img.img2img),
|
||||||
_js="submit_img2img",
|
_js="submit_img2img",
|
||||||
inputs=[
|
inputs=[
|
||||||
dummy_component,
|
dummy_component,
|
||||||
|
@ -743,6 +794,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
|
|
||||||
roll.click(
|
roll.click(
|
||||||
fn=roll_artist,
|
fn=roll_artist,
|
||||||
|
_js="update_img2img_tokens",
|
||||||
inputs=[
|
inputs=[
|
||||||
img2img_prompt,
|
img2img_prompt,
|
||||||
],
|
],
|
||||||
|
@ -753,6 +805,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
|
|
||||||
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
|
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
|
||||||
style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
|
style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
|
||||||
|
style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
|
||||||
|
|
||||||
for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
|
for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
|
||||||
button.click(
|
button.click(
|
||||||
|
@ -764,9 +817,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2],
|
outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2],
|
||||||
)
|
)
|
||||||
|
|
||||||
for button, (prompt, negative_prompt), (style1, style2) in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns):
|
for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
|
||||||
button.click(
|
button.click(
|
||||||
fn=apply_styles,
|
fn=apply_styles,
|
||||||
|
_js=js_func,
|
||||||
inputs=[prompt, negative_prompt, style1, style2],
|
inputs=[prompt, negative_prompt, style1, style2],
|
||||||
outputs=[prompt, negative_prompt, style1, style2],
|
outputs=[prompt, negative_prompt, style1, style2],
|
||||||
)
|
)
|
||||||
|
@ -789,6 +843,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
(denoising_strength, "Denoising strength"),
|
(denoising_strength, "Denoising strength"),
|
||||||
]
|
]
|
||||||
modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt)
|
modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt)
|
||||||
|
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as extras_interface:
|
with gr.Blocks(analytics_enabled=False) as extras_interface:
|
||||||
with gr.Row().style(equal_height=False):
|
with gr.Row().style(equal_height=False):
|
||||||
|
@ -828,7 +883,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
|
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
|
||||||
|
|
||||||
submit.click(
|
submit.click(
|
||||||
fn=run_extras,
|
fn=wrap_gradio_gpu_call(modules.extras.run_extras),
|
||||||
_js="get_extras_tab_index",
|
_js="get_extras_tab_index",
|
||||||
inputs=[
|
inputs=[
|
||||||
dummy_component,
|
dummy_component,
|
||||||
|
@ -878,7 +933,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
pnginfo_send_to_img2img = gr.Button('Send to img2img')
|
pnginfo_send_to_img2img = gr.Button('Send to img2img')
|
||||||
|
|
||||||
image.change(
|
image.change(
|
||||||
fn=wrap_gradio_call(run_pnginfo),
|
fn=wrap_gradio_call(modules.extras.run_pnginfo),
|
||||||
inputs=[image],
|
inputs=[image],
|
||||||
outputs=[html, generation_info, html2],
|
outputs=[html, generation_info, html2],
|
||||||
)
|
)
|
||||||
|
@ -900,6 +955,130 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
with gr.Column(variant='panel'):
|
with gr.Column(variant='panel'):
|
||||||
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
|
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
|
||||||
|
|
||||||
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||||
|
|
||||||
|
with gr.Blocks() as textual_inversion_interface:
|
||||||
|
with gr.Row().style(equal_height=False):
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Group():
|
||||||
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
|
||||||
|
|
||||||
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
|
||||||
|
|
||||||
|
new_embedding_name = gr.Textbox(label="Name")
|
||||||
|
initialization_text = gr.Textbox(label="Initialization text", value="*")
|
||||||
|
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=3):
|
||||||
|
gr.HTML(value="")
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
create_embedding = gr.Button(value="Create", variant='primary')
|
||||||
|
|
||||||
|
with gr.Group():
|
||||||
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>Preprocess images</p>")
|
||||||
|
|
||||||
|
process_src = gr.Textbox(label='Source directory')
|
||||||
|
process_dst = gr.Textbox(label='Destination directory')
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
process_flip = gr.Checkbox(label='Flip')
|
||||||
|
process_split = gr.Checkbox(label='Split into two')
|
||||||
|
process_caption = gr.Checkbox(label='Add caption')
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=3):
|
||||||
|
gr.HTML(value="")
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
run_preprocess = gr.Button(value="Preprocess", variant='primary')
|
||||||
|
|
||||||
|
with gr.Group():
|
||||||
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
|
||||||
|
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||||
|
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
|
||||||
|
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||||
|
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
||||||
|
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
|
||||||
|
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
||||||
|
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
|
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=2):
|
||||||
|
gr.HTML(value="")
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
|
interrupt_training = gr.Button(value="Interrupt")
|
||||||
|
train_embedding = gr.Button(value="Train", variant='primary')
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
progressbar = gr.HTML(elem_id="ti_progressbar")
|
||||||
|
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
|
||||||
|
|
||||||
|
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
|
||||||
|
ti_preview = gr.Image(elem_id='ti_preview', visible=False)
|
||||||
|
ti_progress = gr.HTML(elem_id="ti_progress", value="")
|
||||||
|
ti_outcome = gr.HTML(elem_id="ti_error", value="")
|
||||||
|
setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
|
||||||
|
|
||||||
|
create_embedding.click(
|
||||||
|
fn=modules.textual_inversion.ui.create_embedding,
|
||||||
|
inputs=[
|
||||||
|
new_embedding_name,
|
||||||
|
initialization_text,
|
||||||
|
nvpt,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
train_embedding_name,
|
||||||
|
ti_output,
|
||||||
|
ti_outcome,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
run_preprocess.click(
|
||||||
|
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
|
||||||
|
_js="start_training_textual_inversion",
|
||||||
|
inputs=[
|
||||||
|
process_src,
|
||||||
|
process_dst,
|
||||||
|
process_flip,
|
||||||
|
process_split,
|
||||||
|
process_caption,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
ti_output,
|
||||||
|
ti_outcome,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
train_embedding.click(
|
||||||
|
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
|
||||||
|
_js="start_training_textual_inversion",
|
||||||
|
inputs=[
|
||||||
|
train_embedding_name,
|
||||||
|
learn_rate,
|
||||||
|
dataset_directory,
|
||||||
|
log_directory,
|
||||||
|
steps,
|
||||||
|
create_image_every,
|
||||||
|
save_embedding_every,
|
||||||
|
template_file,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
ti_output,
|
||||||
|
ti_outcome,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
interrupt_training.click(
|
||||||
|
fn=lambda: shared.state.interrupt(),
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
)
|
||||||
|
|
||||||
def create_setting_component(key):
|
def create_setting_component(key):
|
||||||
def fun():
|
def fun():
|
||||||
return opts.data[key] if key in opts.data else opts.data_labels[key].default
|
return opts.data[key] if key in opts.data else opts.data_labels[key].default
|
||||||
|
@ -1002,6 +1181,31 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
_js='function(){}'
|
_js='function(){}'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
|
||||||
|
restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
|
||||||
|
|
||||||
|
|
||||||
|
def reload_scripts():
|
||||||
|
modules.scripts.reload_script_body_only()
|
||||||
|
|
||||||
|
reload_script_bodies.click(
|
||||||
|
fn=reload_scripts,
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
_js='function(){}'
|
||||||
|
)
|
||||||
|
|
||||||
|
def request_restart():
|
||||||
|
settings_interface.gradio_ref.do_restart = True
|
||||||
|
|
||||||
|
restart_gradio.click(
|
||||||
|
fn=request_restart,
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
_js='function(){restart_reload()}'
|
||||||
|
)
|
||||||
|
|
||||||
if column is not None:
|
if column is not None:
|
||||||
column.__exit__()
|
column.__exit__()
|
||||||
|
|
||||||
|
@ -1011,6 +1215,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
(extras_interface, "Extras", "extras"),
|
(extras_interface, "Extras", "extras"),
|
||||||
(pnginfo_interface, "PNG Info", "pnginfo"),
|
(pnginfo_interface, "PNG Info", "pnginfo"),
|
||||||
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
|
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
|
||||||
|
(textual_inversion_interface, "Textual inversion", "ti"),
|
||||||
(settings_interface, "Settings", "settings"),
|
(settings_interface, "Settings", "settings"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -1027,6 +1232,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
|
|
||||||
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
|
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
|
||||||
|
|
||||||
|
settings_interface.gradio_ref = demo
|
||||||
|
|
||||||
with gr.Tabs() as tabs:
|
with gr.Tabs() as tabs:
|
||||||
for interface, label, ifid in interfaces:
|
for interface, label, ifid in interfaces:
|
||||||
with gr.TabItem(label, id=ifid):
|
with gr.TabItem(label, id=ifid):
|
||||||
|
@ -1044,11 +1251,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
|
|
||||||
def modelmerger(*args):
|
def modelmerger(*args):
|
||||||
try:
|
try:
|
||||||
results = run_modelmerger(*args)
|
results = modules.extras.run_modelmerger(*args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Error loading/saving model file:", file=sys.stderr)
|
print("Error loading/saving model file:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
modules.sd_models.list_models() #To remove the potentially missing models from the list
|
modules.sd_models.list_models() # to remove the potentially missing models from the list
|
||||||
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
|
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@ -1206,12 +1413,12 @@ for filename in sorted(os.listdir(jsdir)):
|
||||||
javascript += f"\n<script>{jsfile.read()}</script>"
|
javascript += f"\n<script>{jsfile.read()}</script>"
|
||||||
|
|
||||||
|
|
||||||
def template_response(*args, **kwargs):
|
if 'gradio_routes_templates_response' not in globals():
|
||||||
res = gradio_routes_templates_response(*args, **kwargs)
|
def template_response(*args, **kwargs):
|
||||||
res.body = res.body.replace(b'</head>', f'{javascript}</head>'.encode("utf8"))
|
res = gradio_routes_templates_response(*args, **kwargs)
|
||||||
res.init_headers()
|
res.body = res.body.replace(b'</head>', f'{javascript}</head>'.encode("utf8"))
|
||||||
return res
|
res.init_headers()
|
||||||
|
return res
|
||||||
|
|
||||||
|
gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
|
||||||
gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
|
gradio.routes.templates.TemplateResponse = template_response
|
||||||
gradio.routes.templates.TemplateResponse = template_response
|
|
||||||
|
|
|
@ -13,14 +13,12 @@ Pillow
|
||||||
pytorch_lightning
|
pytorch_lightning
|
||||||
realesrgan
|
realesrgan
|
||||||
scikit-image>=0.19
|
scikit-image>=0.19
|
||||||
git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379
|
|
||||||
timm==0.4.12
|
timm==0.4.12
|
||||||
transformers==4.19.2
|
transformers==4.19.2
|
||||||
torch
|
torch
|
||||||
einops
|
einops
|
||||||
jsonmerge
|
jsonmerge
|
||||||
clean-fid
|
clean-fid
|
||||||
git+https://github.com/openai/CLIP@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
|
|
||||||
resize-right
|
resize-right
|
||||||
torchdiffeq
|
torchdiffeq
|
||||||
kornia
|
kornia
|
||||||
|
|
|
@ -18,7 +18,6 @@ piexif==1.1.3
|
||||||
einops==0.4.1
|
einops==0.4.1
|
||||||
jsonmerge==1.8.0
|
jsonmerge==1.8.0
|
||||||
clean-fid==0.1.29
|
clean-fid==0.1.29
|
||||||
git+https://github.com/openai/CLIP@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
|
|
||||||
resize-right==0.0.2
|
resize-right==0.0.2
|
||||||
torchdiffeq==0.2.3
|
torchdiffeq==0.2.3
|
||||||
kornia==0.6.7
|
kornia==0.6.7
|
||||||
|
|
|
@ -34,7 +34,11 @@ class Script(scripts.Script):
|
||||||
seed = p.seed
|
seed = p.seed
|
||||||
|
|
||||||
init_img = p.init_images[0]
|
init_img = p.init_images[0]
|
||||||
img = upscaler.scaler.upscale(init_img, 2, upscaler.data_path)
|
|
||||||
|
if(upscaler.name != "None"):
|
||||||
|
img = upscaler.scaler.upscale(init_img, 2, upscaler.data_path)
|
||||||
|
else:
|
||||||
|
img = init_img
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
|
|
10
style.css
10
style.css
|
@ -157,7 +157,7 @@ button{
|
||||||
max-width: 10em;
|
max-width: 10em;
|
||||||
}
|
}
|
||||||
|
|
||||||
#txt2img_preview, #img2img_preview{
|
#txt2img_preview, #img2img_preview, #ti_preview{
|
||||||
position: absolute;
|
position: absolute;
|
||||||
width: 320px;
|
width: 320px;
|
||||||
left: 0;
|
left: 0;
|
||||||
|
@ -172,18 +172,18 @@ button{
|
||||||
}
|
}
|
||||||
|
|
||||||
@media screen and (min-width: 768px) {
|
@media screen and (min-width: 768px) {
|
||||||
#txt2img_preview, #img2img_preview {
|
#txt2img_preview, #img2img_preview, #ti_preview {
|
||||||
position: absolute;
|
position: absolute;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@media screen and (max-width: 767px) {
|
@media screen and (max-width: 767px) {
|
||||||
#txt2img_preview, #img2img_preview {
|
#txt2img_preview, #img2img_preview, #ti_preview {
|
||||||
position: relative;
|
position: relative;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0{
|
#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0, #ti_preview div.left-0.top-0{
|
||||||
display: none;
|
display: none;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -247,7 +247,7 @@ input[type="range"]{
|
||||||
#txt2img_negative_prompt, #img2img_negative_prompt{
|
#txt2img_negative_prompt, #img2img_negative_prompt{
|
||||||
}
|
}
|
||||||
|
|
||||||
#txt2img_progressbar, #img2img_progressbar{
|
#txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
|
||||||
position: absolute;
|
position: absolute;
|
||||||
z-index: 1000;
|
z-index: 1000;
|
||||||
right: 0;
|
right: 0;
|
||||||
|
|
19
textual_inversion_templates/style.txt
Normal file
19
textual_inversion_templates/style.txt
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
a painting, art by [name]
|
||||||
|
a rendering, art by [name]
|
||||||
|
a cropped painting, art by [name]
|
||||||
|
the painting, art by [name]
|
||||||
|
a clean painting, art by [name]
|
||||||
|
a dirty painting, art by [name]
|
||||||
|
a dark painting, art by [name]
|
||||||
|
a picture, art by [name]
|
||||||
|
a cool painting, art by [name]
|
||||||
|
a close-up painting, art by [name]
|
||||||
|
a bright painting, art by [name]
|
||||||
|
a cropped painting, art by [name]
|
||||||
|
a good painting, art by [name]
|
||||||
|
a close-up painting, art by [name]
|
||||||
|
a rendition, art by [name]
|
||||||
|
a nice painting, art by [name]
|
||||||
|
a small painting, art by [name]
|
||||||
|
a weird painting, art by [name]
|
||||||
|
a large painting, art by [name]
|
19
textual_inversion_templates/style_filewords.txt
Normal file
19
textual_inversion_templates/style_filewords.txt
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
a painting of [filewords], art by [name]
|
||||||
|
a rendering of [filewords], art by [name]
|
||||||
|
a cropped painting of [filewords], art by [name]
|
||||||
|
the painting of [filewords], art by [name]
|
||||||
|
a clean painting of [filewords], art by [name]
|
||||||
|
a dirty painting of [filewords], art by [name]
|
||||||
|
a dark painting of [filewords], art by [name]
|
||||||
|
a picture of [filewords], art by [name]
|
||||||
|
a cool painting of [filewords], art by [name]
|
||||||
|
a close-up painting of [filewords], art by [name]
|
||||||
|
a bright painting of [filewords], art by [name]
|
||||||
|
a cropped painting of [filewords], art by [name]
|
||||||
|
a good painting of [filewords], art by [name]
|
||||||
|
a close-up painting of [filewords], art by [name]
|
||||||
|
a rendition of [filewords], art by [name]
|
||||||
|
a nice painting of [filewords], art by [name]
|
||||||
|
a small painting of [filewords], art by [name]
|
||||||
|
a weird painting of [filewords], art by [name]
|
||||||
|
a large painting of [filewords], art by [name]
|
27
textual_inversion_templates/subject.txt
Normal file
27
textual_inversion_templates/subject.txt
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
a photo of a [name]
|
||||||
|
a rendering of a [name]
|
||||||
|
a cropped photo of the [name]
|
||||||
|
the photo of a [name]
|
||||||
|
a photo of a clean [name]
|
||||||
|
a photo of a dirty [name]
|
||||||
|
a dark photo of the [name]
|
||||||
|
a photo of my [name]
|
||||||
|
a photo of the cool [name]
|
||||||
|
a close-up photo of a [name]
|
||||||
|
a bright photo of the [name]
|
||||||
|
a cropped photo of a [name]
|
||||||
|
a photo of the [name]
|
||||||
|
a good photo of the [name]
|
||||||
|
a photo of one [name]
|
||||||
|
a close-up photo of the [name]
|
||||||
|
a rendition of the [name]
|
||||||
|
a photo of the clean [name]
|
||||||
|
a rendition of a [name]
|
||||||
|
a photo of a nice [name]
|
||||||
|
a good photo of a [name]
|
||||||
|
a photo of the nice [name]
|
||||||
|
a photo of the small [name]
|
||||||
|
a photo of the weird [name]
|
||||||
|
a photo of the large [name]
|
||||||
|
a photo of a cool [name]
|
||||||
|
a photo of a small [name]
|
27
textual_inversion_templates/subject_filewords.txt
Normal file
27
textual_inversion_templates/subject_filewords.txt
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
a photo of a [name], [filewords]
|
||||||
|
a rendering of a [name], [filewords]
|
||||||
|
a cropped photo of the [name], [filewords]
|
||||||
|
the photo of a [name], [filewords]
|
||||||
|
a photo of a clean [name], [filewords]
|
||||||
|
a photo of a dirty [name], [filewords]
|
||||||
|
a dark photo of the [name], [filewords]
|
||||||
|
a photo of my [name], [filewords]
|
||||||
|
a photo of the cool [name], [filewords]
|
||||||
|
a close-up photo of a [name], [filewords]
|
||||||
|
a bright photo of the [name], [filewords]
|
||||||
|
a cropped photo of a [name], [filewords]
|
||||||
|
a photo of the [name], [filewords]
|
||||||
|
a good photo of the [name], [filewords]
|
||||||
|
a photo of one [name], [filewords]
|
||||||
|
a close-up photo of the [name], [filewords]
|
||||||
|
a rendition of the [name], [filewords]
|
||||||
|
a photo of the clean [name], [filewords]
|
||||||
|
a rendition of a [name], [filewords]
|
||||||
|
a photo of a nice [name], [filewords]
|
||||||
|
a good photo of a [name], [filewords]
|
||||||
|
a photo of the nice [name], [filewords]
|
||||||
|
a photo of the small [name], [filewords]
|
||||||
|
a photo of the weird [name], [filewords]
|
||||||
|
a photo of the large [name], [filewords]
|
||||||
|
a photo of a cool [name], [filewords]
|
||||||
|
a photo of a small [name], [filewords]
|
63
webui.py
63
webui.py
|
@ -1,34 +1,34 @@
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
|
import importlib
|
||||||
from modules import devices
|
from modules import devices
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
import signal
|
import signal
|
||||||
import threading
|
import threading
|
||||||
import modules.paths
|
|
||||||
import modules.codeformer_model as codeformer
|
import modules.codeformer_model as codeformer
|
||||||
import modules.esrgan_model as esrgan
|
|
||||||
import modules.bsrgan_model as bsrgan
|
|
||||||
import modules.extras
|
import modules.extras
|
||||||
import modules.face_restoration
|
import modules.face_restoration
|
||||||
import modules.gfpgan_model as gfpgan
|
import modules.gfpgan_model as gfpgan
|
||||||
import modules.img2img
|
import modules.img2img
|
||||||
import modules.ldsr_model as ldsr
|
|
||||||
import modules.lowvram
|
import modules.lowvram
|
||||||
import modules.realesrgan_model as realesrgan
|
import modules.paths
|
||||||
import modules.scripts
|
import modules.scripts
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
import modules.sd_models
|
import modules.sd_models
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import modules.swinir_model as swinir
|
|
||||||
import modules.txt2img
|
import modules.txt2img
|
||||||
|
|
||||||
import modules.ui
|
import modules.ui
|
||||||
|
from modules import devices
|
||||||
from modules import modelloader
|
from modules import modelloader
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
|
|
||||||
modelloader.cleanup_models()
|
modelloader.cleanup_models()
|
||||||
modules.sd_models.setup_model(cmd_opts.ckpt_dir)
|
modules.sd_models.setup_model()
|
||||||
codeformer.setup_model(cmd_opts.codeformer_models_path)
|
codeformer.setup_model(cmd_opts.codeformer_models_path)
|
||||||
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
|
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
|
||||||
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
|
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
|
||||||
|
@ -46,7 +46,7 @@ def wrap_queued_call(func):
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
|
||||||
def wrap_gradio_gpu_call(func):
|
def wrap_gradio_gpu_call(func, extra_outputs=None):
|
||||||
def f(*args, **kwargs):
|
def f(*args, **kwargs):
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
|
@ -58,6 +58,7 @@ def wrap_gradio_gpu_call(func):
|
||||||
shared.state.current_image = None
|
shared.state.current_image = None
|
||||||
shared.state.current_image_sampling_step = 0
|
shared.state.current_image_sampling_step = 0
|
||||||
shared.state.interrupted = False
|
shared.state.interrupted = False
|
||||||
|
shared.state.textinfo = None
|
||||||
|
|
||||||
with queue_lock:
|
with queue_lock:
|
||||||
res = func(*args, **kwargs)
|
res = func(*args, **kwargs)
|
||||||
|
@ -69,7 +70,7 @@ def wrap_gradio_gpu_call(func):
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
return modules.ui.wrap_gradio_call(f)
|
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
|
||||||
|
|
||||||
|
|
||||||
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
|
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
|
||||||
|
@ -86,22 +87,34 @@ def webui():
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, sigint_handler)
|
signal.signal(signal.SIGINT, sigint_handler)
|
||||||
|
|
||||||
demo = modules.ui.create_ui(
|
while 1:
|
||||||
txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img),
|
|
||||||
img2img=wrap_gradio_gpu_call(modules.img2img.img2img),
|
demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
|
||||||
run_extras=wrap_gradio_gpu_call(modules.extras.run_extras),
|
|
||||||
run_pnginfo=modules.extras.run_pnginfo,
|
demo.launch(
|
||||||
run_modelmerger=modules.extras.run_modelmerger
|
share=cmd_opts.share,
|
||||||
)
|
server_name="0.0.0.0" if cmd_opts.listen else None,
|
||||||
|
server_port=cmd_opts.port,
|
||||||
|
debug=cmd_opts.gradio_debug,
|
||||||
|
auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None,
|
||||||
|
inbrowser=cmd_opts.autolaunch,
|
||||||
|
prevent_thread_lock=True
|
||||||
|
)
|
||||||
|
|
||||||
|
while 1:
|
||||||
|
time.sleep(0.5)
|
||||||
|
if getattr(demo, 'do_restart', False):
|
||||||
|
time.sleep(0.5)
|
||||||
|
demo.close()
|
||||||
|
time.sleep(0.5)
|
||||||
|
break
|
||||||
|
|
||||||
|
print('Reloading Custom Scripts')
|
||||||
|
modules.scripts.reload_scripts(os.path.join(script_path, "scripts"))
|
||||||
|
print('Reloading modules: modules.ui')
|
||||||
|
importlib.reload(modules.ui)
|
||||||
|
print('Restarting Gradio')
|
||||||
|
|
||||||
demo.launch(
|
|
||||||
share=cmd_opts.share,
|
|
||||||
server_name="0.0.0.0" if cmd_opts.listen else None,
|
|
||||||
server_port=cmd_opts.port,
|
|
||||||
debug=cmd_opts.gradio_debug,
|
|
||||||
auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None,
|
|
||||||
inbrowser=cmd_opts.autolaunch,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in a new issue