Merge remote-tracking branch 'upstream/master' into sub-quad_attn_opt
This commit is contained in:
commit
3bfe2bb549
27 changed files with 453 additions and 157 deletions
8
javascript/dragdrop.js
vendored
8
javascript/dragdrop.js
vendored
|
@ -9,11 +9,19 @@ function dropReplaceImage( imgWrap, files ) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const tmpFile = files[0];
|
||||||
|
|
||||||
imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
|
imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
|
||||||
const callback = () => {
|
const callback = () => {
|
||||||
const fileInput = imgWrap.querySelector('input[type="file"]');
|
const fileInput = imgWrap.querySelector('input[type="file"]');
|
||||||
if ( fileInput ) {
|
if ( fileInput ) {
|
||||||
|
if ( files.length === 0 ) {
|
||||||
|
files = new DataTransfer();
|
||||||
|
files.items.add(tmpFile);
|
||||||
|
fileInput.files = files.files;
|
||||||
|
} else {
|
||||||
fileInput.files = files;
|
fileInput.files = files;
|
||||||
|
}
|
||||||
fileInput.dispatchEvent(new Event('change'));
|
fileInput.dispatchEvent(new Event('change'));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -81,9 +81,6 @@ titles = {
|
||||||
|
|
||||||
"vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",
|
"vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",
|
||||||
|
|
||||||
"Highres. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition",
|
|
||||||
"Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.",
|
|
||||||
|
|
||||||
"Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
|
"Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
|
||||||
"Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",
|
"Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",
|
||||||
|
|
||||||
|
@ -100,7 +97,13 @@ titles = {
|
||||||
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
|
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
|
||||||
|
|
||||||
"Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resoluton and lower quality.",
|
"Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resoluton and lower quality.",
|
||||||
"Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality."
|
"Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality.",
|
||||||
|
|
||||||
|
"Hires. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition",
|
||||||
|
"Hires steps": "Number of sampling steps for upscaled picture. If 0, uses same as for original.",
|
||||||
|
"Upscale by": "Adjusts the size of the image by multiplying the original width and height by the selected value. Ignored if either Resize width to or Resize height to are non-zero.",
|
||||||
|
"Resize width to": "Resizes image to this width. If 0, width is inferred from either of two nearby sliders.",
|
||||||
|
"Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders."
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -148,8 +148,8 @@ function showGalleryImage() {
|
||||||
if(e && e.parentElement.tagName == 'DIV'){
|
if(e && e.parentElement.tagName == 'DIV'){
|
||||||
e.style.cursor='pointer'
|
e.style.cursor='pointer'
|
||||||
e.style.userSelect='none'
|
e.style.userSelect='none'
|
||||||
e.addEventListener('click', function (evt) {
|
e.addEventListener('mousedown', function (evt) {
|
||||||
if(!opts.js_modal_lightbox) return;
|
if(!opts.js_modal_lightbox || evt.button != 0) return;
|
||||||
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
|
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
|
||||||
showModal(evt)
|
showModal(evt)
|
||||||
}, true);
|
}, true);
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// various functions for interation with ui.py not large enough to warrant putting them in separate files
|
// various functions for interaction with ui.py not large enough to warrant putting them in separate files
|
||||||
|
|
||||||
function set_theme(theme){
|
function set_theme(theme){
|
||||||
gradioURL = window.location.href
|
gradioURL = window.location.href
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import time
|
import time
|
||||||
|
import datetime
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from gradio.processing_utils import decode_base64_to_file
|
from gradio.processing_utils import decode_base64_to_file
|
||||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException
|
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
|
||||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||||
from secrets import compare_digest
|
from secrets import compare_digest
|
||||||
|
|
||||||
|
@ -18,7 +19,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
|
||||||
from modules.textual_inversion.preprocess import preprocess
|
from modules.textual_inversion.preprocess import preprocess
|
||||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||||
from PIL import PngImagePlugin,Image
|
from PIL import PngImagePlugin,Image
|
||||||
from modules.sd_models import checkpoints_list
|
from modules.sd_models import checkpoints_list, find_checkpoint_config
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
from modules import devices
|
from modules import devices
|
||||||
from typing import List
|
from typing import List
|
||||||
|
@ -67,6 +68,27 @@ def encode_pil_to_base64(image):
|
||||||
bytes_data = output_bytes.getvalue()
|
bytes_data = output_bytes.getvalue()
|
||||||
return base64.b64encode(bytes_data)
|
return base64.b64encode(bytes_data)
|
||||||
|
|
||||||
|
def api_middleware(app: FastAPI):
|
||||||
|
@app.middleware("http")
|
||||||
|
async def log_and_time(req: Request, call_next):
|
||||||
|
ts = time.time()
|
||||||
|
res: Response = await call_next(req)
|
||||||
|
duration = str(round(time.time() - ts, 4))
|
||||||
|
res.headers["X-Process-Time"] = duration
|
||||||
|
endpoint = req.scope.get('path', 'err')
|
||||||
|
if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
|
||||||
|
print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
|
||||||
|
t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
|
||||||
|
code = res.status_code,
|
||||||
|
ver = req.scope.get('http_version', '0.0'),
|
||||||
|
cli = req.scope.get('client', ('0:0.0.0', 0))[0],
|
||||||
|
prot = req.scope.get('scheme', 'err'),
|
||||||
|
method = req.scope.get('method', 'err'),
|
||||||
|
endpoint = endpoint,
|
||||||
|
duration = duration,
|
||||||
|
))
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
class Api:
|
class Api:
|
||||||
def __init__(self, app: FastAPI, queue_lock: Lock):
|
def __init__(self, app: FastAPI, queue_lock: Lock):
|
||||||
|
@ -79,6 +101,7 @@ class Api:
|
||||||
self.router = APIRouter()
|
self.router = APIRouter()
|
||||||
self.app = app
|
self.app = app
|
||||||
self.queue_lock = queue_lock
|
self.queue_lock = queue_lock
|
||||||
|
api_middleware(self.app)
|
||||||
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
|
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
|
||||||
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
|
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
|
||||||
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
|
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
|
||||||
|
@ -303,7 +326,7 @@ class Api:
|
||||||
return upscalers
|
return upscalers
|
||||||
|
|
||||||
def get_sd_models(self):
|
def get_sd_models(self):
|
||||||
return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
|
return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
|
||||||
|
|
||||||
def get_hypernetworks(self):
|
def get_hypernetworks(self):
|
||||||
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
||||||
|
|
|
@ -2,9 +2,30 @@ import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
|
||||||
|
def print_error_explanation(message):
|
||||||
|
lines = message.strip().split("\n")
|
||||||
|
max_len = max([len(x) for x in lines])
|
||||||
|
|
||||||
|
print('=' * max_len, file=sys.stderr)
|
||||||
|
for line in lines:
|
||||||
|
print(line, file=sys.stderr)
|
||||||
|
print('=' * max_len, file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
def display(e: Exception, task):
|
||||||
|
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
message = str(e)
|
||||||
|
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
|
||||||
|
print_error_explanation("""
|
||||||
|
The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file.
|
||||||
|
See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
def run(code, task):
|
def run(code, task):
|
||||||
try:
|
try:
|
||||||
code()
|
code()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"{task}: {type(e).__name__}", file=sys.stderr)
|
display(task, e)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
|
@ -19,8 +19,6 @@ from modules.shared import opts
|
||||||
import modules.gfpgan_model
|
import modules.gfpgan_model
|
||||||
from modules.ui import plaintext_to_html
|
from modules.ui import plaintext_to_html
|
||||||
import modules.codeformer_model
|
import modules.codeformer_model
|
||||||
import piexif
|
|
||||||
import piexif.helper
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
|
|
||||||
|
@ -58,6 +56,9 @@ cached_images: LruCache = LruCache(max_size=5)
|
||||||
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
|
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
|
shared.state.begin()
|
||||||
|
shared.state.job = 'extras'
|
||||||
|
|
||||||
imageArr = []
|
imageArr = []
|
||||||
# Also keep track of original file names
|
# Also keep track of original file names
|
||||||
imageNameArr = []
|
imageNameArr = []
|
||||||
|
@ -94,6 +95,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||||
# Extra operation definitions
|
# Extra operation definitions
|
||||||
|
|
||||||
def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||||
|
shared.state.job = 'extras-gfpgan'
|
||||||
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
|
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
|
||||||
res = Image.fromarray(restored_img)
|
res = Image.fromarray(restored_img)
|
||||||
|
|
||||||
|
@ -104,6 +106,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||||
return (res, info)
|
return (res, info)
|
||||||
|
|
||||||
def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||||
|
shared.state.job = 'extras-codeformer'
|
||||||
restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
|
restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
|
||||||
res = Image.fromarray(restored_img)
|
res = Image.fromarray(restored_img)
|
||||||
|
|
||||||
|
@ -114,6 +117,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||||
return (res, info)
|
return (res, info)
|
||||||
|
|
||||||
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
|
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
|
||||||
|
shared.state.job = 'extras-upscale'
|
||||||
upscaler = shared.sd_upscalers[scaler_index]
|
upscaler = shared.sd_upscalers[scaler_index]
|
||||||
res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
|
res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
|
||||||
if mode == 1 and crop:
|
if mode == 1 and crop:
|
||||||
|
@ -180,6 +184,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||||
for image, image_name in zip(imageArr, imageNameArr):
|
for image, image_name in zip(imageArr, imageNameArr):
|
||||||
if image is None:
|
if image is None:
|
||||||
return outputs, "Please select an input image.", ''
|
return outputs, "Please select an input image.", ''
|
||||||
|
|
||||||
|
shared.state.textinfo = f'Processing image {image_name}'
|
||||||
|
|
||||||
existing_pnginfo = image.info or {}
|
existing_pnginfo = image.info or {}
|
||||||
|
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
|
@ -193,6 +200,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||||
else:
|
else:
|
||||||
basename = ''
|
basename = ''
|
||||||
|
|
||||||
|
if opts.enable_pnginfo: # append info before save
|
||||||
|
image.info = existing_pnginfo
|
||||||
|
image.info["extras"] = info
|
||||||
|
|
||||||
if save_output:
|
if save_output:
|
||||||
# Add upscaler name as a suffix.
|
# Add upscaler name as a suffix.
|
||||||
suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
|
suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
|
||||||
|
@ -203,10 +214,6 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||||
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
|
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
|
||||||
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
|
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
|
||||||
|
|
||||||
if opts.enable_pnginfo:
|
|
||||||
image.info = existing_pnginfo
|
|
||||||
image.info["extras"] = info
|
|
||||||
|
|
||||||
if extras_mode != 2 or show_extras_results :
|
if extras_mode != 2 or show_extras_results :
|
||||||
outputs.append(image)
|
outputs.append(image)
|
||||||
|
|
||||||
|
@ -242,6 +249,9 @@ def run_pnginfo(image):
|
||||||
|
|
||||||
|
|
||||||
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
|
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
|
||||||
|
shared.state.begin()
|
||||||
|
shared.state.job = 'model-merge'
|
||||||
|
|
||||||
def weighted_sum(theta0, theta1, alpha):
|
def weighted_sum(theta0, theta1, alpha):
|
||||||
return ((1 - alpha) * theta0) + (alpha * theta1)
|
return ((1 - alpha) * theta0) + (alpha * theta1)
|
||||||
|
|
||||||
|
@ -263,8 +273,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||||
theta_func1, theta_func2 = theta_funcs[interp_method]
|
theta_func1, theta_func2 = theta_funcs[interp_method]
|
||||||
|
|
||||||
if theta_func1 and not tertiary_model_info:
|
if theta_func1 and not tertiary_model_info:
|
||||||
|
shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
|
||||||
|
shared.state.end()
|
||||||
return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
|
return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
|
||||||
|
|
||||||
|
shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
|
||||||
print(f"Loading {secondary_model_info.filename}...")
|
print(f"Loading {secondary_model_info.filename}...")
|
||||||
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
||||||
|
|
||||||
|
@ -281,6 +294,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||||
theta_1[key] = torch.zeros_like(theta_1[key])
|
theta_1[key] = torch.zeros_like(theta_1[key])
|
||||||
del theta_2
|
del theta_2
|
||||||
|
|
||||||
|
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
|
||||||
print(f"Loading {primary_model_info.filename}...")
|
print(f"Loading {primary_model_info.filename}...")
|
||||||
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
|
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
|
||||||
|
|
||||||
|
@ -291,6 +305,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||||
a = theta_0[key]
|
a = theta_0[key]
|
||||||
b = theta_1[key]
|
b = theta_1[key]
|
||||||
|
|
||||||
|
shared.state.textinfo = f'Merging layer {key}'
|
||||||
# this enables merging an inpainting model (A) with another one (B);
|
# this enables merging an inpainting model (A) with another one (B);
|
||||||
# where normal model would have 4 channels, for latenst space, inpainting model would
|
# where normal model would have 4 channels, for latenst space, inpainting model would
|
||||||
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
|
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
|
||||||
|
@ -303,8 +318,6 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||||
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
|
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
|
||||||
result_is_inpainting_model = True
|
result_is_inpainting_model = True
|
||||||
else:
|
else:
|
||||||
assert a.shape == b.shape, f'Incompatible shapes for layer {key}: A is {a.shape}, and B is {b.shape}'
|
|
||||||
|
|
||||||
theta_0[key] = theta_func2(a, b, multiplier)
|
theta_0[key] = theta_func2(a, b, multiplier)
|
||||||
|
|
||||||
if save_as_half:
|
if save_as_half:
|
||||||
|
@ -332,6 +345,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||||
|
|
||||||
output_modelname = os.path.join(ckpt_dir, filename)
|
output_modelname = os.path.join(ckpt_dir, filename)
|
||||||
|
|
||||||
|
shared.state.textinfo = f"Saving to {output_modelname}..."
|
||||||
print(f"Saving to {output_modelname}...")
|
print(f"Saving to {output_modelname}...")
|
||||||
|
|
||||||
_, extension = os.path.splitext(output_modelname)
|
_, extension = os.path.splitext(output_modelname)
|
||||||
|
@ -343,4 +357,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||||
sd_models.list_models()
|
sd_models.list_models()
|
||||||
|
|
||||||
print("Checkpoint saved.")
|
print("Checkpoint saved.")
|
||||||
|
shared.state.textinfo = "Checkpoint saved to " + output_modelname
|
||||||
|
shared.state.end()
|
||||||
|
|
||||||
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
|
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
|
||||||
|
|
|
@ -212,11 +212,10 @@ def restore_old_hires_fix_params(res):
|
||||||
firstpass_width = math.ceil(scale * width / 64) * 64
|
firstpass_width = math.ceil(scale * width / 64) * 64
|
||||||
firstpass_height = math.ceil(scale * height / 64) * 64
|
firstpass_height = math.ceil(scale * height / 64) * 64
|
||||||
|
|
||||||
hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height
|
|
||||||
|
|
||||||
res['Size-1'] = firstpass_width
|
res['Size-1'] = firstpass_width
|
||||||
res['Size-2'] = firstpass_height
|
res['Size-2'] = firstpass_height
|
||||||
res['Hires upscale'] = hr_scale
|
res['Hires resize-1'] = width
|
||||||
|
res['Hires resize-2'] = height
|
||||||
|
|
||||||
|
|
||||||
def parse_generation_parameters(x: str):
|
def parse_generation_parameters(x: str):
|
||||||
|
@ -276,6 +275,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||||
hypernet_hash = res.get("Hypernet hash", None)
|
hypernet_hash = res.get("Hypernet hash", None)
|
||||||
res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
|
res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
|
||||||
|
|
||||||
|
if "Hires resize-1" not in res:
|
||||||
|
res["Hires resize-1"] = 0
|
||||||
|
res["Hires resize-2"] = 0
|
||||||
|
|
||||||
restore_old_hires_fix_params(res)
|
restore_old_hires_fix_params(res)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
|
@ -402,10 +402,8 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
||||||
|
|
||||||
shared.reload_hypernetworks()
|
shared.reload_hypernetworks()
|
||||||
|
|
||||||
return fn
|
|
||||||
|
|
||||||
|
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
|
||||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||||
from modules import images
|
from modules import images
|
||||||
|
|
||||||
|
@ -417,6 +415,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||||
shared.loaded_hypernetwork = Hypernetwork()
|
shared.loaded_hypernetwork = Hypernetwork()
|
||||||
shared.loaded_hypernetwork.load(path)
|
shared.loaded_hypernetwork.load(path)
|
||||||
|
|
||||||
|
shared.state.job = "train-hypernetwork"
|
||||||
shared.state.textinfo = "Initializing hypernetwork training..."
|
shared.state.textinfo = "Initializing hypernetwork training..."
|
||||||
shared.state.job_count = steps
|
shared.state.job_count = steps
|
||||||
|
|
||||||
|
@ -448,6 +447,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||||
|
|
||||||
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
||||||
|
|
||||||
|
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
|
||||||
|
if clip_grad:
|
||||||
|
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
||||||
|
|
||||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
|
|
||||||
|
@ -524,6 +527,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||||
if shared.state.interrupted:
|
if shared.state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if clip_grad:
|
||||||
|
clip_grad_sched.step(hypernetwork.step)
|
||||||
|
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||||
if tag_drop_out != 0 or shuffle_tags:
|
if tag_drop_out != 0 or shuffle_tags:
|
||||||
|
@ -538,14 +544,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||||
|
|
||||||
_loss_step += loss.item()
|
_loss_step += loss.item()
|
||||||
scaler.scale(loss).backward()
|
scaler.scale(loss).backward()
|
||||||
|
|
||||||
# go back until we reach gradient accumulation steps
|
# go back until we reach gradient accumulation steps
|
||||||
if (j + 1) % gradient_step != 0:
|
if (j + 1) % gradient_step != 0:
|
||||||
continue
|
continue
|
||||||
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
|
|
||||||
# scaler.unscale_(optimizer)
|
if clip_grad:
|
||||||
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
|
clip_grad(weights, clip_grad_sched.learn_rate)
|
||||||
# torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
|
|
||||||
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
|
|
||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
hypernetwork.step += 1
|
hypernetwork.step += 1
|
||||||
|
|
|
@ -136,7 +136,8 @@ class InterrogateModels:
|
||||||
|
|
||||||
def interrogate(self, pil_image):
|
def interrogate(self, pil_image):
|
||||||
res = ""
|
res = ""
|
||||||
|
shared.state.begin()
|
||||||
|
shared.state.job = 'interrogate'
|
||||||
try:
|
try:
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
|
@ -177,5 +178,6 @@ class InterrogateModels:
|
||||||
res += "<error>"
|
res += "<error>"
|
||||||
|
|
||||||
self.unload()
|
self.unload()
|
||||||
|
shared.state.end()
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
|
@ -76,6 +76,24 @@ def apply_overlay(image, paste_loc, index, overlays):
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def txt2img_image_conditioning(sd_model, x, width, height):
|
||||||
|
if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
|
||||||
|
# Dummy zero conditioning if we're not using inpainting model.
|
||||||
|
# Still takes up a bit of memory, but no encoder call.
|
||||||
|
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
|
||||||
|
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
|
||||||
|
|
||||||
|
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
||||||
|
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
|
||||||
|
image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
|
||||||
|
|
||||||
|
# Add the fake full 1s mask to the first dimension.
|
||||||
|
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
||||||
|
image_conditioning = image_conditioning.to(x.dtype)
|
||||||
|
|
||||||
|
return image_conditioning
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionProcessing():
|
class StableDiffusionProcessing():
|
||||||
"""
|
"""
|
||||||
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
||||||
|
@ -136,28 +154,12 @@ class StableDiffusionProcessing():
|
||||||
self.all_negative_prompts = None
|
self.all_negative_prompts = None
|
||||||
self.all_seeds = None
|
self.all_seeds = None
|
||||||
self.all_subseeds = None
|
self.all_subseeds = None
|
||||||
|
self.iteration = 0
|
||||||
|
|
||||||
def txt2img_image_conditioning(self, x, width=None, height=None):
|
def txt2img_image_conditioning(self, x, width=None, height=None):
|
||||||
if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
|
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
|
||||||
# Dummy zero conditioning if we're not using inpainting model.
|
|
||||||
# Still takes up a bit of memory, but no encoder call.
|
|
||||||
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
|
|
||||||
return x.new_zeros(x.shape[0], 5, 1, 1)
|
|
||||||
|
|
||||||
self.is_using_inpainting_conditioning = True
|
return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
|
||||||
|
|
||||||
height = height or self.height
|
|
||||||
width = width or self.width
|
|
||||||
|
|
||||||
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
|
||||||
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
|
|
||||||
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
|
|
||||||
|
|
||||||
# Add the fake full 1s mask to the first dimension.
|
|
||||||
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
|
||||||
image_conditioning = image_conditioning.to(x.dtype)
|
|
||||||
|
|
||||||
return image_conditioning
|
|
||||||
|
|
||||||
def depth2img_image_conditioning(self, source_image):
|
def depth2img_image_conditioning(self, source_image):
|
||||||
# Use the AddMiDaS helper to Format our source image to suit the MiDaS model
|
# Use the AddMiDaS helper to Format our source image to suit the MiDaS model
|
||||||
|
@ -420,7 +422,7 @@ def fix_seed(p):
|
||||||
p.subseed = get_fixed_seed(p.subseed)
|
p.subseed = get_fixed_seed(p.subseed)
|
||||||
|
|
||||||
|
|
||||||
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
|
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
|
||||||
index = position_in_batch + iteration * p.batch_size
|
index = position_in_batch + iteration * p.batch_size
|
||||||
|
|
||||||
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
|
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
|
||||||
|
@ -544,6 +546,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||||
state.job_count = p.n_iter
|
state.job_count = p.n_iter
|
||||||
|
|
||||||
for n in range(p.n_iter):
|
for n in range(p.n_iter):
|
||||||
|
p.iteration = n
|
||||||
|
|
||||||
if state.skipped:
|
if state.skipped:
|
||||||
state.skipped = False
|
state.skipped = False
|
||||||
|
|
||||||
|
@ -658,12 +662,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||||
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
sampler = None
|
sampler = None
|
||||||
|
|
||||||
def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs):
|
def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.enable_hr = enable_hr
|
self.enable_hr = enable_hr
|
||||||
self.denoising_strength = denoising_strength
|
self.denoising_strength = denoising_strength
|
||||||
self.hr_scale = hr_scale
|
self.hr_scale = hr_scale
|
||||||
self.hr_upscaler = hr_upscaler
|
self.hr_upscaler = hr_upscaler
|
||||||
|
self.hr_second_pass_steps = hr_second_pass_steps
|
||||||
|
self.hr_resize_x = hr_resize_x
|
||||||
|
self.hr_resize_y = hr_resize_y
|
||||||
|
self.hr_upscale_to_x = hr_resize_x
|
||||||
|
self.hr_upscale_to_y = hr_resize_y
|
||||||
|
|
||||||
if firstphase_width != 0 or firstphase_height != 0:
|
if firstphase_width != 0 or firstphase_height != 0:
|
||||||
print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
|
print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
|
||||||
|
@ -671,14 +680,60 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
self.width = firstphase_width
|
self.width = firstphase_width
|
||||||
self.height = firstphase_height
|
self.height = firstphase_height
|
||||||
|
|
||||||
|
self.truncate_x = 0
|
||||||
|
self.truncate_y = 0
|
||||||
|
|
||||||
|
|
||||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||||
if self.enable_hr:
|
if self.enable_hr:
|
||||||
if state.job_count == -1:
|
if self.hr_resize_x == 0 and self.hr_resize_y == 0:
|
||||||
state.job_count = self.n_iter * 2
|
|
||||||
else:
|
|
||||||
state.job_count = state.job_count * 2
|
|
||||||
|
|
||||||
self.extra_generation_params["Hires upscale"] = self.hr_scale
|
self.extra_generation_params["Hires upscale"] = self.hr_scale
|
||||||
|
self.hr_upscale_to_x = int(self.width * self.hr_scale)
|
||||||
|
self.hr_upscale_to_y = int(self.height * self.hr_scale)
|
||||||
|
else:
|
||||||
|
self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"
|
||||||
|
|
||||||
|
if self.hr_resize_y == 0:
|
||||||
|
self.hr_upscale_to_x = self.hr_resize_x
|
||||||
|
self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
|
||||||
|
elif self.hr_resize_x == 0:
|
||||||
|
self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
|
||||||
|
self.hr_upscale_to_y = self.hr_resize_y
|
||||||
|
else:
|
||||||
|
target_w = self.hr_resize_x
|
||||||
|
target_h = self.hr_resize_y
|
||||||
|
src_ratio = self.width / self.height
|
||||||
|
dst_ratio = self.hr_resize_x / self.hr_resize_y
|
||||||
|
|
||||||
|
if src_ratio < dst_ratio:
|
||||||
|
self.hr_upscale_to_x = self.hr_resize_x
|
||||||
|
self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
|
||||||
|
else:
|
||||||
|
self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
|
||||||
|
self.hr_upscale_to_y = self.hr_resize_y
|
||||||
|
|
||||||
|
self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
|
||||||
|
self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
|
||||||
|
|
||||||
|
# special case: the user has chosen to do nothing
|
||||||
|
if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
|
||||||
|
self.enable_hr = False
|
||||||
|
self.denoising_strength = None
|
||||||
|
self.extra_generation_params.pop("Hires upscale", None)
|
||||||
|
self.extra_generation_params.pop("Hires resize", None)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not state.processing_has_refined_job_count:
|
||||||
|
if state.job_count == -1:
|
||||||
|
state.job_count = self.n_iter
|
||||||
|
|
||||||
|
shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
|
||||||
|
state.job_count = state.job_count * 2
|
||||||
|
state.processing_has_refined_job_count = True
|
||||||
|
|
||||||
|
if self.hr_second_pass_steps:
|
||||||
|
self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps
|
||||||
|
|
||||||
if self.hr_upscaler is not None:
|
if self.hr_upscaler is not None:
|
||||||
self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
|
self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
|
||||||
|
|
||||||
|
@ -695,8 +750,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
if not self.enable_hr:
|
if not self.enable_hr:
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
target_width = int(self.width * self.hr_scale)
|
target_width = self.hr_upscale_to_x
|
||||||
target_height = int(self.height * self.hr_scale)
|
target_height = self.hr_upscale_to_y
|
||||||
|
|
||||||
def save_intermediate(image, index):
|
def save_intermediate(image, index):
|
||||||
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
|
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
|
||||||
|
@ -705,15 +760,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
return
|
return
|
||||||
|
|
||||||
if not isinstance(image, Image.Image):
|
if not isinstance(image, Image.Image):
|
||||||
image = sd_samplers.sample_to_image(image, index)
|
image = sd_samplers.sample_to_image(image, index, approximation=0)
|
||||||
|
|
||||||
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
|
info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
|
||||||
|
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")
|
||||||
|
|
||||||
if latent_scale_mode is not None:
|
if latent_scale_mode is not None:
|
||||||
for i in range(samples.shape[0]):
|
for i in range(samples.shape[0]):
|
||||||
save_intermediate(samples, i)
|
save_intermediate(samples, i)
|
||||||
|
|
||||||
samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode)
|
samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
|
||||||
|
|
||||||
# Avoid making the inpainting conditioning unless necessary as
|
# Avoid making the inpainting conditioning unless necessary as
|
||||||
# this does need some extra compute to decode / encode the image again.
|
# this does need some extra compute to decode / encode the image again.
|
||||||
|
@ -750,13 +806,15 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
|
|
||||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||||
|
|
||||||
|
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
|
||||||
|
|
||||||
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
|
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
|
||||||
|
|
||||||
# GC now before running the next img2img to prevent running out of memory
|
# GC now before running the next img2img to prevent running out of memory
|
||||||
x = None
|
x = None
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
|
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
|
@ -34,24 +34,33 @@ def apply_optimizations():
|
||||||
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||||
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
||||||
|
|
||||||
|
optimization_method = None
|
||||||
|
|
||||||
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
|
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
|
||||||
print("Applying xformers cross attention optimization.")
|
print("Applying xformers cross attention optimization.")
|
||||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
|
||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
|
||||||
|
optimization_method = 'xformers'
|
||||||
elif cmd_opts.opt_sub_quad_attention:
|
elif cmd_opts.opt_sub_quad_attention:
|
||||||
print("Applying sub-quadratic cross attention optimization.")
|
print("Applying sub-quadratic cross attention optimization.")
|
||||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
|
||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
|
||||||
|
optimization_method = 'sub-quadratic'
|
||||||
elif cmd_opts.opt_split_attention_v1:
|
elif cmd_opts.opt_split_attention_v1:
|
||||||
print("Applying v1 cross attention optimization.")
|
print("Applying v1 cross attention optimization.")
|
||||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
||||||
|
optimization_method = 'V1'
|
||||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
|
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
|
||||||
print("Applying cross attention optimization (InvokeAI).")
|
print("Applying cross attention optimization (InvokeAI).")
|
||||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
|
||||||
|
optimization_method = 'InvokeAI'
|
||||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
||||||
print("Applying cross attention optimization (Doggettx).")
|
print("Applying cross attention optimization (Doggettx).")
|
||||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
|
||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
|
||||||
|
optimization_method = 'Doggettx'
|
||||||
|
|
||||||
|
return optimization_method
|
||||||
|
|
||||||
|
|
||||||
def undo_optimizations():
|
def undo_optimizations():
|
||||||
|
@ -72,6 +81,7 @@ class StableDiffusionModelHijack:
|
||||||
layers = None
|
layers = None
|
||||||
circular_enabled = False
|
circular_enabled = False
|
||||||
clip = None
|
clip = None
|
||||||
|
optimization_method = None
|
||||||
|
|
||||||
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
|
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
|
||||||
|
|
||||||
|
@ -91,7 +101,7 @@ class StableDiffusionModelHijack:
|
||||||
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
|
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
|
||||||
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||||
|
|
||||||
apply_optimizations()
|
self.optimization_method = apply_optimizations()
|
||||||
|
|
||||||
self.clip = m.cond_stage_model
|
self.clip = m.cond_stage_model
|
||||||
|
|
||||||
|
|
|
@ -97,8 +97,11 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
|
||||||
|
|
||||||
|
|
||||||
def should_hijack_inpainting(checkpoint_info):
|
def should_hijack_inpainting(checkpoint_info):
|
||||||
|
from modules import sd_models
|
||||||
|
|
||||||
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
||||||
cfg_basename = os.path.basename(checkpoint_info.config).lower()
|
cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
|
||||||
|
|
||||||
return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
|
return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp
|
||||||
model_dir = "Stable-diffusion"
|
model_dir = "Stable-diffusion"
|
||||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||||
|
|
||||||
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
|
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
|
||||||
checkpoints_list = {}
|
checkpoints_list = {}
|
||||||
checkpoints_loaded = collections.OrderedDict()
|
checkpoints_loaded = collections.OrderedDict()
|
||||||
|
|
||||||
|
@ -48,6 +48,14 @@ def checkpoint_tiles():
|
||||||
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
|
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
|
||||||
|
|
||||||
|
|
||||||
|
def find_checkpoint_config(info):
|
||||||
|
config = os.path.splitext(info.filename)[0] + ".yaml"
|
||||||
|
if os.path.exists(config):
|
||||||
|
return config
|
||||||
|
|
||||||
|
return shared.cmd_opts.config
|
||||||
|
|
||||||
|
|
||||||
def list_models():
|
def list_models():
|
||||||
checkpoints_list.clear()
|
checkpoints_list.clear()
|
||||||
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
|
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
|
||||||
|
@ -73,7 +81,7 @@ def list_models():
|
||||||
if os.path.exists(cmd_ckpt):
|
if os.path.exists(cmd_ckpt):
|
||||||
h = model_hash(cmd_ckpt)
|
h = model_hash(cmd_ckpt)
|
||||||
title, short_model_name = modeltitle(cmd_ckpt, h)
|
title, short_model_name = modeltitle(cmd_ckpt, h)
|
||||||
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
|
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
|
||||||
shared.opts.data['sd_model_checkpoint'] = title
|
shared.opts.data['sd_model_checkpoint'] = title
|
||||||
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||||
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
||||||
|
@ -81,12 +89,7 @@ def list_models():
|
||||||
h = model_hash(filename)
|
h = model_hash(filename)
|
||||||
title, short_model_name = modeltitle(filename, h)
|
title, short_model_name = modeltitle(filename, h)
|
||||||
|
|
||||||
basename, _ = os.path.splitext(filename)
|
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
|
||||||
config = basename + ".yaml"
|
|
||||||
if not os.path.exists(config):
|
|
||||||
config = shared.cmd_opts.config
|
|
||||||
|
|
||||||
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
|
|
||||||
|
|
||||||
|
|
||||||
def get_closet_checkpoint_match(searchString):
|
def get_closet_checkpoint_match(searchString):
|
||||||
|
@ -168,7 +171,10 @@ def get_state_dict_from_checkpoint(pl_sd):
|
||||||
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
||||||
_, extension = os.path.splitext(checkpoint_file)
|
_, extension = os.path.splitext(checkpoint_file)
|
||||||
if extension.lower() == ".safetensors":
|
if extension.lower() == ".safetensors":
|
||||||
pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location)
|
device = map_location or shared.weight_load_location
|
||||||
|
if device is None:
|
||||||
|
device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
|
||||||
|
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
||||||
else:
|
else:
|
||||||
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
||||||
|
|
||||||
|
@ -278,12 +284,14 @@ def enable_midas_autodownload():
|
||||||
|
|
||||||
midas.api.load_model = load_model_wrapper
|
midas.api.load_model = load_model_wrapper
|
||||||
|
|
||||||
|
|
||||||
def load_model(checkpoint_info=None):
|
def load_model(checkpoint_info=None):
|
||||||
from modules import lowvram, sd_hijack
|
from modules import lowvram, sd_hijack
|
||||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||||
|
checkpoint_config = find_checkpoint_config(checkpoint_info)
|
||||||
|
|
||||||
if checkpoint_info.config != shared.cmd_opts.config:
|
if checkpoint_config != shared.cmd_opts.config:
|
||||||
print(f"Loading config from: {checkpoint_info.config}")
|
print(f"Loading config from: {checkpoint_config}")
|
||||||
|
|
||||||
if shared.sd_model:
|
if shared.sd_model:
|
||||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
||||||
|
@ -291,7 +299,7 @@ def load_model(checkpoint_info=None):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
sd_config = OmegaConf.load(checkpoint_info.config)
|
sd_config = OmegaConf.load(checkpoint_config)
|
||||||
|
|
||||||
if should_hijack_inpainting(checkpoint_info):
|
if should_hijack_inpainting(checkpoint_info):
|
||||||
# Hardcoded config for now...
|
# Hardcoded config for now...
|
||||||
|
@ -300,9 +308,6 @@ def load_model(checkpoint_info=None):
|
||||||
sd_config.model.params.unet_config.params.in_channels = 9
|
sd_config.model.params.unet_config.params.in_channels = 9
|
||||||
sd_config.model.params.finetune_keys = None
|
sd_config.model.params.finetune_keys = None
|
||||||
|
|
||||||
# Create a "fake" config with a different name so that we know to unload it when switching models.
|
|
||||||
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
|
|
||||||
|
|
||||||
if not hasattr(sd_config.model.params, "use_ema"):
|
if not hasattr(sd_config.model.params, "use_ema"):
|
||||||
sd_config.model.params.use_ema = False
|
sd_config.model.params.use_ema = False
|
||||||
|
|
||||||
|
@ -312,6 +317,7 @@ def load_model(checkpoint_info=None):
|
||||||
sd_config.model.params.unet_config.params.use_fp16 = False
|
sd_config.model.params.unet_config.params.use_fp16 = False
|
||||||
|
|
||||||
sd_model = instantiate_from_config(sd_config.model)
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
|
|
||||||
load_model_weights(sd_model, checkpoint_info)
|
load_model_weights(sd_model, checkpoint_info)
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
|
@ -340,10 +346,13 @@ def reload_model_weights(sd_model=None, info=None):
|
||||||
if not sd_model:
|
if not sd_model:
|
||||||
sd_model = shared.sd_model
|
sd_model = shared.sd_model
|
||||||
|
|
||||||
|
current_checkpoint_info = sd_model.sd_checkpoint_info
|
||||||
|
checkpoint_config = find_checkpoint_config(current_checkpoint_info)
|
||||||
|
|
||||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||||
return
|
return
|
||||||
|
|
||||||
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
|
if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
|
||||||
del sd_model
|
del sd_model
|
||||||
checkpoints_loaded.clear()
|
checkpoints_loaded.clear()
|
||||||
load_model(checkpoint_info)
|
load_model(checkpoint_info)
|
||||||
|
@ -356,8 +365,13 @@ def reload_model_weights(sd_model=None, info=None):
|
||||||
|
|
||||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||||
|
|
||||||
|
try:
|
||||||
load_model_weights(sd_model, checkpoint_info)
|
load_model_weights(sd_model, checkpoint_info)
|
||||||
|
except Exception as e:
|
||||||
|
print("Failed to load checkpoint, restoring previous")
|
||||||
|
load_model_weights(sd_model, current_checkpoint_info)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
sd_hijack.model_hijack.hijack(sd_model)
|
sd_hijack.model_hijack.hijack(sd_model)
|
||||||
script_callbacks.model_loaded_callback(sd_model)
|
script_callbacks.model_loaded_callback(sd_model)
|
||||||
|
|
||||||
|
@ -365,4 +379,5 @@ def reload_model_weights(sd_model=None, info=None):
|
||||||
sd_model.to(devices.device)
|
sd_model.to(devices.device)
|
||||||
|
|
||||||
print("Weights loaded.")
|
print("Weights loaded.")
|
||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
|
@ -97,8 +97,9 @@ sampler_extra_params = {
|
||||||
|
|
||||||
def setup_img2img_steps(p, steps=None):
|
def setup_img2img_steps(p, steps=None):
|
||||||
if opts.img2img_fix_steps or steps is not None:
|
if opts.img2img_fix_steps or steps is not None:
|
||||||
steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
|
requested_steps = (steps or p.steps)
|
||||||
t_enc = p.steps - 1
|
steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
|
||||||
|
t_enc = requested_steps - 1
|
||||||
else:
|
else:
|
||||||
steps = p.steps
|
steps = p.steps
|
||||||
t_enc = int(min(p.denoising_strength, 0.999) * steps)
|
t_enc = int(min(p.denoising_strength, 0.999) * steps)
|
||||||
|
|
|
@ -14,7 +14,7 @@ import modules.interrogate
|
||||||
import modules.memmon
|
import modules.memmon
|
||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.devices as devices
|
import modules.devices as devices
|
||||||
from modules import localization, sd_vae, extensions, script_loading
|
from modules import localization, sd_vae, extensions, script_loading, errors
|
||||||
from modules.paths import models_path, script_path, sd_path
|
from modules.paths import models_path, script_path, sd_path
|
||||||
|
|
||||||
|
|
||||||
|
@ -86,6 +86,7 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode
|
||||||
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
||||||
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
||||||
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||||
|
parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
|
||||||
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
|
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
|
||||||
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
|
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
|
||||||
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
|
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
|
||||||
|
@ -156,6 +157,7 @@ class State:
|
||||||
job = ""
|
job = ""
|
||||||
job_no = 0
|
job_no = 0
|
||||||
job_count = 0
|
job_count = 0
|
||||||
|
processing_has_refined_job_count = False
|
||||||
job_timestamp = '0'
|
job_timestamp = '0'
|
||||||
sampling_step = 0
|
sampling_step = 0
|
||||||
sampling_steps = 0
|
sampling_steps = 0
|
||||||
|
@ -186,6 +188,7 @@ class State:
|
||||||
"interrupted": self.interrupted,
|
"interrupted": self.interrupted,
|
||||||
"job": self.job,
|
"job": self.job,
|
||||||
"job_count": self.job_count,
|
"job_count": self.job_count,
|
||||||
|
"job_timestamp": self.job_timestamp,
|
||||||
"job_no": self.job_no,
|
"job_no": self.job_no,
|
||||||
"sampling_step": self.sampling_step,
|
"sampling_step": self.sampling_step,
|
||||||
"sampling_steps": self.sampling_steps,
|
"sampling_steps": self.sampling_steps,
|
||||||
|
@ -196,6 +199,7 @@ class State:
|
||||||
def begin(self):
|
def begin(self):
|
||||||
self.sampling_step = 0
|
self.sampling_step = 0
|
||||||
self.job_count = -1
|
self.job_count = -1
|
||||||
|
self.processing_has_refined_job_count = False
|
||||||
self.job_no = 0
|
self.job_no = 0
|
||||||
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||||
self.current_latent = None
|
self.current_latent = None
|
||||||
|
@ -216,12 +220,13 @@ class State:
|
||||||
|
|
||||||
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
|
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
|
||||||
def set_current_image(self):
|
def set_current_image(self):
|
||||||
|
if not parallel_processing_allowed:
|
||||||
|
return
|
||||||
|
|
||||||
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
|
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
|
||||||
self.do_set_current_image()
|
self.do_set_current_image()
|
||||||
|
|
||||||
def do_set_current_image(self):
|
def do_set_current_image(self):
|
||||||
if not parallel_processing_allowed:
|
|
||||||
return
|
|
||||||
if self.current_latent is None:
|
if self.current_latent is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -233,6 +238,7 @@ class State:
|
||||||
|
|
||||||
self.current_image_sampling_step = self.sampling_step
|
self.current_image_sampling_step = self.sampling_step
|
||||||
|
|
||||||
|
|
||||||
state = State()
|
state = State()
|
||||||
|
|
||||||
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
|
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
|
||||||
|
@ -359,7 +365,7 @@ options_templates.update(options_section(('system', "System"), {
|
||||||
options_templates.update(options_section(('training', "Training"), {
|
options_templates.update(options_section(('training', "Training"), {
|
||||||
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
||||||
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
|
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
|
||||||
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
|
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
|
||||||
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
||||||
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
||||||
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||||
|
@ -498,7 +504,12 @@ class Options:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if self.data_labels[key].onchange is not None:
|
if self.data_labels[key].onchange is not None:
|
||||||
|
try:
|
||||||
self.data_labels[key].onchange()
|
self.data_labels[key].onchange()
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, f"changing setting {key} to {value}")
|
||||||
|
setattr(self, key, oldval)
|
||||||
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -563,8 +574,11 @@ if os.path.exists(config_filename):
|
||||||
|
|
||||||
latent_upscale_default_mode = "Latent"
|
latent_upscale_default_mode = "Latent"
|
||||||
latent_upscale_modes = {
|
latent_upscale_modes = {
|
||||||
"Latent": "bilinear",
|
"Latent": {"mode": "bilinear", "antialias": False},
|
||||||
"Latent (nearest)": "nearest",
|
"Latent (antialiased)": {"mode": "bilinear", "antialias": True},
|
||||||
|
"Latent (bicubic)": {"mode": "bicubic", "antialias": False},
|
||||||
|
"Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True},
|
||||||
|
"Latent (nearest)": {"mode": "nearest", "antialias": False},
|
||||||
}
|
}
|
||||||
|
|
||||||
sd_upscalers = []
|
sd_upscalers = []
|
||||||
|
|
|
@ -58,14 +58,19 @@ class LearnRateScheduler:
|
||||||
|
|
||||||
self.finished = False
|
self.finished = False
|
||||||
|
|
||||||
def apply(self, optimizer, step_number):
|
def step(self, step_number):
|
||||||
if step_number < self.end_step:
|
if step_number < self.end_step:
|
||||||
return
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
(self.learn_rate, self.end_step) = next(self.schedules)
|
(self.learn_rate, self.end_step) = next(self.schedules)
|
||||||
except Exception:
|
except StopIteration:
|
||||||
self.finished = True
|
self.finished = True
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def apply(self, optimizer, step_number):
|
||||||
|
if not self.step(step_number):
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
|
|
|
@ -124,6 +124,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
||||||
|
|
||||||
files = listfiles(src)
|
files = listfiles(src)
|
||||||
|
|
||||||
|
shared.state.job = "preprocess"
|
||||||
shared.state.textinfo = "Preprocessing..."
|
shared.state.textinfo = "Preprocessing..."
|
||||||
shared.state.job_count = len(files)
|
shared.state.job_count = len(files)
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,7 @@ class Embedding:
|
||||||
self.cached_checksum = None
|
self.cached_checksum = None
|
||||||
self.sd_checkpoint = None
|
self.sd_checkpoint = None
|
||||||
self.sd_checkpoint_name = None
|
self.sd_checkpoint_name = None
|
||||||
|
self.optimizer_state_dict = None
|
||||||
|
|
||||||
def save(self, filename):
|
def save(self, filename):
|
||||||
embedding_data = {
|
embedding_data = {
|
||||||
|
@ -41,6 +42,13 @@ class Embedding:
|
||||||
|
|
||||||
torch.save(embedding_data, filename)
|
torch.save(embedding_data, filename)
|
||||||
|
|
||||||
|
if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
|
||||||
|
optimizer_saved_dict = {
|
||||||
|
'hash': self.checksum(),
|
||||||
|
'optimizer_state_dict': self.optimizer_state_dict,
|
||||||
|
}
|
||||||
|
torch.save(optimizer_saved_dict, filename + '.optim')
|
||||||
|
|
||||||
def checksum(self):
|
def checksum(self):
|
||||||
if self.cached_checksum is not None:
|
if self.cached_checksum is not None:
|
||||||
return self.cached_checksum
|
return self.cached_checksum
|
||||||
|
@ -95,9 +103,10 @@ class EmbeddingDatabase:
|
||||||
self.expected_shape = self.get_expected_shape()
|
self.expected_shape = self.get_expected_shape()
|
||||||
|
|
||||||
def process_file(path, filename):
|
def process_file(path, filename):
|
||||||
name = os.path.splitext(filename)[0]
|
name, ext = os.path.splitext(filename)
|
||||||
|
ext = ext.upper()
|
||||||
|
|
||||||
if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
||||||
embed_image = Image.open(path)
|
embed_image = Image.open(path)
|
||||||
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
|
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
|
||||||
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
||||||
|
@ -105,8 +114,10 @@ class EmbeddingDatabase:
|
||||||
else:
|
else:
|
||||||
data = extract_image_data_embed(embed_image)
|
data = extract_image_data_embed(embed_image)
|
||||||
name = data.get('name', name)
|
name = data.get('name', name)
|
||||||
else:
|
elif ext in ['.BIN', '.PT']:
|
||||||
data = torch.load(path, map_location="cpu")
|
data = torch.load(path, map_location="cpu")
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
# textual inversion embeddings
|
# textual inversion embeddings
|
||||||
if 'string_to_param' in data:
|
if 'string_to_param' in data:
|
||||||
|
@ -240,11 +251,12 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
|
||||||
if save_model_every or create_image_every:
|
if save_model_every or create_image_every:
|
||||||
assert log_directory, "Log directory is empty"
|
assert log_directory, "Log directory is empty"
|
||||||
|
|
||||||
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
save_embedding_every = save_embedding_every or 0
|
save_embedding_every = save_embedding_every or 0
|
||||||
create_image_every = create_image_every or 0
|
create_image_every = create_image_every or 0
|
||||||
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
|
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
|
||||||
|
|
||||||
|
shared.state.job = "train-embedding"
|
||||||
shared.state.textinfo = "Initializing textual inversion training..."
|
shared.state.textinfo = "Initializing textual inversion training..."
|
||||||
shared.state.job_count = steps
|
shared.state.job_count = steps
|
||||||
|
|
||||||
|
@ -282,6 +294,11 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||||
return embedding, filename
|
return embedding, filename
|
||||||
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
||||||
|
|
||||||
|
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
|
||||||
|
torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
|
||||||
|
None
|
||||||
|
if clip_grad:
|
||||||
|
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
|
||||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
||||||
|
@ -300,6 +317,19 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||||
|
|
||||||
embedding.vec.requires_grad = True
|
embedding.vec.requires_grad = True
|
||||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
|
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
|
||||||
|
if shared.opts.save_optimizer_state:
|
||||||
|
optimizer_state_dict = None
|
||||||
|
if os.path.exists(filename + '.optim'):
|
||||||
|
optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
|
||||||
|
if embedding.checksum() == optimizer_saved_dict.get('hash', None):
|
||||||
|
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
||||||
|
|
||||||
|
if optimizer_state_dict is not None:
|
||||||
|
optimizer.load_state_dict(optimizer_state_dict)
|
||||||
|
print("Loaded existing optimizer from checkpoint")
|
||||||
|
else:
|
||||||
|
print("No saved optimizer exists in checkpoint")
|
||||||
|
|
||||||
scaler = torch.cuda.amp.GradScaler()
|
scaler = torch.cuda.amp.GradScaler()
|
||||||
|
|
||||||
batch_size = ds.batch_size
|
batch_size = ds.batch_size
|
||||||
|
@ -315,6 +345,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||||
forced_filename = "<none>"
|
forced_filename = "<none>"
|
||||||
embedding_yet_to_be_embedded = False
|
embedding_yet_to_be_embedded = False
|
||||||
|
|
||||||
|
is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
|
||||||
|
img_c = None
|
||||||
|
|
||||||
pbar = tqdm.tqdm(total=steps - initial_step)
|
pbar = tqdm.tqdm(total=steps - initial_step)
|
||||||
try:
|
try:
|
||||||
for i in range((steps-initial_step) * gradient_step):
|
for i in range((steps-initial_step) * gradient_step):
|
||||||
|
@ -332,14 +365,22 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||||
if shared.state.interrupted:
|
if shared.state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if clip_grad:
|
||||||
|
clip_grad_sched.step(embedding.step)
|
||||||
|
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
# c = stack_conds(batch.cond).to(devices.device)
|
|
||||||
# mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
|
|
||||||
# print(mask)
|
|
||||||
# c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
|
|
||||||
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||||
c = shared.sd_model.cond_stage_model(batch.cond_text)
|
c = shared.sd_model.cond_stage_model(batch.cond_text)
|
||||||
loss = shared.sd_model(x, c)[0] / gradient_step
|
|
||||||
|
if is_training_inpainting_model:
|
||||||
|
if img_c is None:
|
||||||
|
img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
|
||||||
|
|
||||||
|
cond = {"c_concat": [img_c], "c_crossattn": [c]}
|
||||||
|
else:
|
||||||
|
cond = c
|
||||||
|
|
||||||
|
loss = shared.sd_model(x, cond)[0] / gradient_step
|
||||||
del x
|
del x
|
||||||
|
|
||||||
_loss_step += loss.item()
|
_loss_step += loss.item()
|
||||||
|
@ -348,6 +389,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||||
# go back until we reach gradient accumulation steps
|
# go back until we reach gradient accumulation steps
|
||||||
if (j + 1) % gradient_step != 0:
|
if (j + 1) % gradient_step != 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if clip_grad:
|
||||||
|
clip_grad(embedding.vec, clip_grad_sched.learn_rate)
|
||||||
|
|
||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
embedding.step += 1
|
embedding.step += 1
|
||||||
|
@ -366,9 +411,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||||
# Before saving, change name to match current checkpoint.
|
# Before saving, change name to match current checkpoint.
|
||||||
embedding_name_every = f'{embedding_name}-{steps_done}'
|
embedding_name_every = f'{embedding_name}-{steps_done}'
|
||||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
|
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
|
||||||
#if shared.opts.save_optimizer_state:
|
save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
|
||||||
#embedding.optimizer_state_dict = optimizer.state_dict()
|
|
||||||
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
|
|
||||||
embedding_yet_to_be_embedded = True
|
embedding_yet_to_be_embedded = True
|
||||||
|
|
||||||
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
|
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
|
||||||
|
@ -458,7 +501,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
</p>
|
</p>
|
||||||
"""
|
"""
|
||||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||||
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
pass
|
pass
|
||||||
|
@ -470,7 +513,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
|
|
||||||
return embedding, filename
|
return embedding, filename
|
||||||
|
|
||||||
def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
|
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
|
||||||
old_embedding_name = embedding.name
|
old_embedding_name = embedding.name
|
||||||
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
|
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
|
||||||
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
|
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
|
||||||
|
@ -481,6 +524,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache
|
||||||
if remove_cached_checksum:
|
if remove_cached_checksum:
|
||||||
embedding.cached_checksum = None
|
embedding.cached_checksum = None
|
||||||
embedding.name = embedding_name
|
embedding.name = embedding_name
|
||||||
|
embedding.optimizer_state_dict = optimizer.state_dict()
|
||||||
embedding.save(filename)
|
embedding.save(filename)
|
||||||
except:
|
except:
|
||||||
embedding.sd_checkpoint = old_sd_checkpoint
|
embedding.sd_checkpoint = old_sd_checkpoint
|
||||||
|
|
|
@ -8,7 +8,7 @@ import modules.processing as processing
|
||||||
from modules.ui import plaintext_to_html
|
from modules.ui import plaintext_to_html
|
||||||
|
|
||||||
|
|
||||||
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, *args):
|
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
|
||||||
p = StableDiffusionProcessingTxt2Img(
|
p = StableDiffusionProcessingTxt2Img(
|
||||||
sd_model=shared.sd_model,
|
sd_model=shared.sd_model,
|
||||||
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
|
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
|
||||||
|
@ -35,6 +35,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
|
||||||
denoising_strength=denoising_strength if enable_hr else None,
|
denoising_strength=denoising_strength if enable_hr else None,
|
||||||
hr_scale=hr_scale,
|
hr_scale=hr_scale,
|
||||||
hr_upscaler=hr_upscaler,
|
hr_upscaler=hr_upscaler,
|
||||||
|
hr_second_pass_steps=hr_second_pass_steps,
|
||||||
|
hr_resize_x=hr_resize_x,
|
||||||
|
hr_resize_y=hr_resize_y,
|
||||||
)
|
)
|
||||||
|
|
||||||
p.scripts = modules.scripts.scripts_txt2img
|
p.scripts = modules.scripts.scripts_txt2img
|
||||||
|
|
|
@ -162,16 +162,14 @@ def save_files(js_data, images, do_make_zip, index):
|
||||||
return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
|
return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
|
||||||
|
|
||||||
|
|
||||||
|
def calc_time_left(progress, threshold, label, force_display, show_eta):
|
||||||
|
|
||||||
def calc_time_left(progress, threshold, label, force_display):
|
|
||||||
if progress == 0:
|
if progress == 0:
|
||||||
return ""
|
return ""
|
||||||
else:
|
else:
|
||||||
time_since_start = time.time() - shared.state.time_start
|
time_since_start = time.time() - shared.state.time_start
|
||||||
eta = (time_since_start/progress)
|
eta = (time_since_start/progress)
|
||||||
eta_relative = eta-time_since_start
|
eta_relative = eta-time_since_start
|
||||||
if (eta_relative > threshold and progress > 0.02) or force_display:
|
if (eta_relative > threshold and show_eta) or force_display:
|
||||||
if eta_relative > 3600:
|
if eta_relative > 3600:
|
||||||
return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
|
return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
|
||||||
elif eta_relative > 60:
|
elif eta_relative > 60:
|
||||||
|
@ -193,7 +191,10 @@ def check_progress_call(id_part):
|
||||||
if shared.state.sampling_steps > 0:
|
if shared.state.sampling_steps > 0:
|
||||||
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
|
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
|
||||||
|
|
||||||
time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display )
|
# Show progress percentage and time left at the same moment, and base it also on steps done
|
||||||
|
show_eta = progress >= 0.01 or shared.state.sampling_step >= 10
|
||||||
|
|
||||||
|
time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta)
|
||||||
if time_left != "":
|
if time_left != "":
|
||||||
shared.state.time_left_force_display = True
|
shared.state.time_left_force_display = True
|
||||||
|
|
||||||
|
@ -201,7 +202,7 @@ def check_progress_call(id_part):
|
||||||
|
|
||||||
progressbar = ""
|
progressbar = ""
|
||||||
if opts.show_progressbar:
|
if opts.show_progressbar:
|
||||||
progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{" " * 2 + str(int(progress*100))+"%" + time_left if progress > 0.01 else ""}</div></div>"""
|
progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{" " * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}</div></div>"""
|
||||||
|
|
||||||
image = gr_show(False)
|
image = gr_show(False)
|
||||||
preview_visibility = gr_show(False)
|
preview_visibility = gr_show(False)
|
||||||
|
@ -635,10 +636,11 @@ def create_sampler_and_steps_selection(choices, tabname):
|
||||||
if opts.samplers_in_dropdown:
|
if opts.samplers_in_dropdown:
|
||||||
with FormRow(elem_id=f"sampler_selection_{tabname}"):
|
with FormRow(elem_id=f"sampler_selection_{tabname}"):
|
||||||
sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
|
sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
|
||||||
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
|
sampler_index.save_to_config = True
|
||||||
|
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
|
||||||
else:
|
else:
|
||||||
with FormGroup(elem_id=f"sampler_selection_{tabname}"):
|
with FormGroup(elem_id=f"sampler_selection_{tabname}"):
|
||||||
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
|
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
|
||||||
sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
|
sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
|
||||||
|
|
||||||
return steps, sampler_index
|
return steps, sampler_index
|
||||||
|
@ -707,11 +709,17 @@ def create_ui():
|
||||||
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
|
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
|
||||||
|
|
||||||
elif category == "hires_fix":
|
elif category == "hires_fix":
|
||||||
with FormRow(visible=False, elem_id="txt2img_hires_fix") as hr_options:
|
with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
|
||||||
|
with FormRow(elem_id="txt2img_hires_fix_row1"):
|
||||||
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
|
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
|
||||||
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
|
hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
|
||||||
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
|
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
|
||||||
|
|
||||||
|
with FormRow(elem_id="txt2img_hires_fix_row2"):
|
||||||
|
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
|
||||||
|
hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
|
||||||
|
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
|
||||||
|
|
||||||
elif category == "batch":
|
elif category == "batch":
|
||||||
if not opts.dimensions_and_batch_together:
|
if not opts.dimensions_and_batch_together:
|
||||||
with FormRow(elem_id="txt2img_column_batch"):
|
with FormRow(elem_id="txt2img_column_batch"):
|
||||||
|
@ -751,6 +759,9 @@ def create_ui():
|
||||||
denoising_strength,
|
denoising_strength,
|
||||||
hr_scale,
|
hr_scale,
|
||||||
hr_upscaler,
|
hr_upscaler,
|
||||||
|
hr_second_pass_steps,
|
||||||
|
hr_resize_x,
|
||||||
|
hr_resize_y,
|
||||||
] + custom_inputs,
|
] + custom_inputs,
|
||||||
|
|
||||||
outputs=[
|
outputs=[
|
||||||
|
@ -802,6 +813,9 @@ def create_ui():
|
||||||
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
|
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
|
||||||
(hr_scale, "Hires upscale"),
|
(hr_scale, "Hires upscale"),
|
||||||
(hr_upscaler, "Hires upscaler"),
|
(hr_upscaler, "Hires upscaler"),
|
||||||
|
(hr_second_pass_steps, "Hires steps"),
|
||||||
|
(hr_resize_x, "Hires resize-1"),
|
||||||
|
(hr_resize_y, "Hires resize-2"),
|
||||||
*modules.scripts.scripts_txt2img.infotext_fields
|
*modules.scripts.scripts_txt2img.infotext_fields
|
||||||
]
|
]
|
||||||
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
|
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
|
||||||
|
@ -1279,38 +1293,48 @@ def create_ui():
|
||||||
|
|
||||||
with gr.Tab(label="Train"):
|
with gr.Tab(label="Train"):
|
||||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
|
||||||
with gr.Row():
|
with FormRow():
|
||||||
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||||
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
|
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
|
||||||
with gr.Row():
|
|
||||||
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
|
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
|
||||||
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
|
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
|
||||||
with gr.Row():
|
|
||||||
|
with FormRow():
|
||||||
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
|
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
|
||||||
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
|
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
|
||||||
|
|
||||||
|
with FormRow():
|
||||||
|
clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
|
||||||
|
clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
|
||||||
|
|
||||||
|
with FormRow():
|
||||||
batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
|
batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
|
||||||
gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
|
gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
|
||||||
|
|
||||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
|
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
|
||||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
|
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
|
||||||
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file")
|
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file")
|
||||||
training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
|
training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
|
||||||
training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
|
training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
|
||||||
steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
|
steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
|
||||||
|
|
||||||
|
with FormRow():
|
||||||
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
|
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
|
||||||
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
|
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
|
||||||
|
|
||||||
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
|
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
|
||||||
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
|
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
|
||||||
with gr.Row():
|
|
||||||
shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
|
shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
|
||||||
tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
|
tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
|
||||||
with gr.Row():
|
|
||||||
latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
|
latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
|
||||||
interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training")
|
interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training")
|
||||||
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork")
|
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork")
|
||||||
train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
|
|
||||||
|
|
||||||
params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
|
params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
|
||||||
|
|
||||||
|
@ -1400,6 +1424,8 @@ def create_ui():
|
||||||
training_width,
|
training_width,
|
||||||
training_height,
|
training_height,
|
||||||
steps,
|
steps,
|
||||||
|
clip_grad_mode,
|
||||||
|
clip_grad_value,
|
||||||
shuffle_tags,
|
shuffle_tags,
|
||||||
tag_drop_out,
|
tag_drop_out,
|
||||||
latent_sampling_method,
|
latent_sampling_method,
|
||||||
|
@ -1429,6 +1455,8 @@ def create_ui():
|
||||||
training_width,
|
training_width,
|
||||||
training_height,
|
training_height,
|
||||||
steps,
|
steps,
|
||||||
|
clip_grad_mode,
|
||||||
|
clip_grad_value,
|
||||||
shuffle_tags,
|
shuffle_tags,
|
||||||
tag_drop_out,
|
tag_drop_out,
|
||||||
latent_sampling_method,
|
latent_sampling_method,
|
||||||
|
@ -1793,6 +1821,7 @@ def create_ui():
|
||||||
visit(img2img_interface, loadsave, "img2img")
|
visit(img2img_interface, loadsave, "img2img")
|
||||||
visit(extras_interface, loadsave, "extras")
|
visit(extras_interface, loadsave, "extras")
|
||||||
visit(modelmerger_interface, loadsave, "modelmerger")
|
visit(modelmerger_interface, loadsave, "modelmerger")
|
||||||
|
visit(train_interface, loadsave, "train")
|
||||||
|
|
||||||
if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
|
if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
|
||||||
with open(ui_config_file, "w", encoding="utf8") as file:
|
with open(ui_config_file, "w", encoding="utf8") as file:
|
||||||
|
|
|
@ -9,7 +9,7 @@ gradio==3.15.0
|
||||||
invisible-watermark
|
invisible-watermark
|
||||||
numpy
|
numpy
|
||||||
omegaconf
|
omegaconf
|
||||||
opencv-python
|
opencv-contrib-python
|
||||||
requests
|
requests
|
||||||
piexif
|
piexif
|
||||||
Pillow
|
Pillow
|
||||||
|
|
|
@ -5,7 +5,7 @@ basicsr==1.4.2
|
||||||
gfpgan==1.3.8
|
gfpgan==1.3.8
|
||||||
gradio==3.15.0
|
gradio==3.15.0
|
||||||
numpy==1.23.3
|
numpy==1.23.3
|
||||||
Pillow==9.2.0
|
Pillow==9.4.0
|
||||||
realesrgan==0.3.0
|
realesrgan==0.3.0
|
||||||
torch
|
torch
|
||||||
omegaconf==2.2.3
|
omegaconf==2.2.3
|
||||||
|
@ -26,5 +26,5 @@ lark==1.1.2
|
||||||
inflection==0.5.1
|
inflection==0.5.1
|
||||||
GitPython==3.1.27
|
GitPython==3.1.27
|
||||||
torchsde==0.2.5
|
torchsde==0.2.5
|
||||||
safetensors==0.2.5
|
safetensors==0.2.7
|
||||||
httpcore<=0.15
|
httpcore<=0.15
|
||||||
|
|
|
@ -4,7 +4,7 @@ function gradioApp() {
|
||||||
}
|
}
|
||||||
|
|
||||||
function get_uiCurrentTab() {
|
function get_uiCurrentTab() {
|
||||||
return gradioApp().querySelector('.tabs button:not(.border-transparent)')
|
return gradioApp().querySelector('#tabs button:not(.border-transparent)')
|
||||||
}
|
}
|
||||||
|
|
||||||
function get_uiCurrentTabContent() {
|
function get_uiCurrentTabContent() {
|
||||||
|
|
|
@ -10,7 +10,7 @@ import numpy as np
|
||||||
import modules.scripts as scripts
|
import modules.scripts as scripts
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import images, paths, sd_samplers
|
from modules import images, paths, sd_samplers, processing
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
|
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
|
@ -285,6 +285,7 @@ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d
|
||||||
re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
|
re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
|
||||||
re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*")
|
re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*")
|
||||||
|
|
||||||
|
|
||||||
class Script(scripts.Script):
|
class Script(scripts.Script):
|
||||||
def title(self):
|
def title(self):
|
||||||
return "X/Y plot"
|
return "X/Y plot"
|
||||||
|
@ -403,12 +404,33 @@ class Script(scripts.Script):
|
||||||
print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})")
|
print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})")
|
||||||
shared.total_tqdm.updateTotal(total_steps * p.n_iter)
|
shared.total_tqdm.updateTotal(total_steps * p.n_iter)
|
||||||
|
|
||||||
|
grid_infotext = [None]
|
||||||
|
|
||||||
def cell(x, y):
|
def cell(x, y):
|
||||||
pc = copy(p)
|
pc = copy(p)
|
||||||
x_opt.apply(pc, x, xs)
|
x_opt.apply(pc, x, xs)
|
||||||
y_opt.apply(pc, y, ys)
|
y_opt.apply(pc, y, ys)
|
||||||
|
|
||||||
return process_images(pc)
|
res = process_images(pc)
|
||||||
|
|
||||||
|
if grid_infotext[0] is None:
|
||||||
|
pc.extra_generation_params = copy(pc.extra_generation_params)
|
||||||
|
|
||||||
|
if x_opt.label != 'Nothing':
|
||||||
|
pc.extra_generation_params["X Type"] = x_opt.label
|
||||||
|
pc.extra_generation_params["X Values"] = x_values
|
||||||
|
if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
|
||||||
|
pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs])
|
||||||
|
|
||||||
|
if y_opt.label != 'Nothing':
|
||||||
|
pc.extra_generation_params["Y Type"] = y_opt.label
|
||||||
|
pc.extra_generation_params["Y Values"] = y_values
|
||||||
|
if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
|
||||||
|
pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys])
|
||||||
|
|
||||||
|
grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
with SharedSettingsStackHelper():
|
with SharedSettingsStackHelper():
|
||||||
processed = draw_xy_grid(
|
processed = draw_xy_grid(
|
||||||
|
@ -423,6 +445,6 @@ class Script(scripts.Script):
|
||||||
)
|
)
|
||||||
|
|
||||||
if opts.grid_save:
|
if opts.grid_save:
|
||||||
images.save_image(processed.images[0], p.outpath_grids, "xy_grid", extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
|
images.save_image(processed.images[0], p.outpath_grids, "xy_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
|
||||||
|
|
||||||
return processed
|
return processed
|
||||||
|
|
|
@ -611,7 +611,7 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h
|
||||||
padding-top: 0.9em;
|
padding-top: 0.9em;
|
||||||
}
|
}
|
||||||
|
|
||||||
#img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form{
|
#img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form, #train_tabs div.gr-form .gr-form{
|
||||||
border: none;
|
border: none;
|
||||||
padding-bottom: 0.5em;
|
padding-bottom: 0.5em;
|
||||||
}
|
}
|
||||||
|
|
16
webui.py
16
webui.py
|
@ -9,7 +9,7 @@ from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.middleware.gzip import GZipMiddleware
|
from fastapi.middleware.gzip import GZipMiddleware
|
||||||
|
|
||||||
from modules import import_hook
|
from modules import import_hook, errors
|
||||||
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
|
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
|
|
||||||
|
@ -61,7 +61,15 @@ def initialize():
|
||||||
modelloader.load_upscalers()
|
modelloader.load_upscalers()
|
||||||
|
|
||||||
modules.sd_vae.refresh_vae_list()
|
modules.sd_vae.refresh_vae_list()
|
||||||
|
|
||||||
|
try:
|
||||||
modules.sd_models.load_model()
|
modules.sd_models.load_model()
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, "loading stable diffusion model")
|
||||||
|
print("", file=sys.stderr)
|
||||||
|
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
|
||||||
|
exit(1)
|
||||||
|
|
||||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
|
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
|
||||||
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||||
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||||
|
@ -92,11 +100,11 @@ def initialize():
|
||||||
|
|
||||||
def setup_cors(app):
|
def setup_cors(app):
|
||||||
if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
|
if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
|
||||||
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'])
|
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
|
||||||
elif cmd_opts.cors_allow_origins:
|
elif cmd_opts.cors_allow_origins:
|
||||||
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'])
|
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
|
||||||
elif cmd_opts.cors_allow_origins_regex:
|
elif cmd_opts.cors_allow_origins_regex:
|
||||||
app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'])
|
app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
|
||||||
|
|
||||||
|
|
||||||
def create_api(app):
|
def create_api(app):
|
||||||
|
|
Loading…
Reference in a new issue