Merge branch 'master' into api-authorization

This commit is contained in:
Maiko Tan 2022-11-19 20:13:07 +08:00
commit 336c341a7c
No known key found for this signature in database
GPG key ID: 0F3B49C721E5F453
28 changed files with 193 additions and 150 deletions

View file

@ -9,9 +9,9 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest from secrets import compare_digest
import modules.shared as shared import modules.shared as shared
from modules import sd_samplers
from modules.api.models import * from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo from modules.extras import run_extras, run_pnginfo
from PIL import PngImagePlugin from PIL import PngImagePlugin
from modules.sd_models import checkpoints_list from modules.sd_models import checkpoints_list
@ -28,8 +28,12 @@ def upscaler_to_index(name: str):
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}") raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) def validate_sampler_name(name):
config = sd_samplers.all_samplers_map.get(name, None)
if config is None:
raise HTTPException(status_code=404, detail="Sampler not found")
return name
def setUpscalers(req: dict): def setUpscalers(req: dict):
reqDict = vars(req) reqDict = vars(req)
@ -77,6 +81,7 @@ class Api:
self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel) self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel) self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
@ -103,14 +108,9 @@ class Api:
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"}) raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found")
populate = txt2imgreq.copy(update={ # Override __init__ params populate = txt2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model, "sd_model": shared.sd_model,
"sampler_index": sampler_index[0], "sampler_name": validate_sampler_name(txt2imgreq.sampler_index),
"do_not_save_samples": True, "do_not_save_samples": True,
"do_not_save_grid": True "do_not_save_grid": True
} }
@ -130,12 +130,6 @@ class Api:
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index)
if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found")
init_images = img2imgreq.init_images init_images = img2imgreq.init_images
if init_images is None: if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found") raise HTTPException(status_code=404, detail="Init image not found")
@ -144,10 +138,9 @@ class Api:
if mask: if mask:
mask = decode_base64_to_image(mask) mask = decode_base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model, "sd_model": shared.sd_model,
"sampler_index": sampler_index[0], "sampler_name": validate_sampler_name(img2imgreq.sampler_index),
"do_not_save_samples": True, "do_not_save_samples": True,
"do_not_save_grid": True, "do_not_save_grid": True,
"mask": mask "mask": mask
@ -266,6 +259,9 @@ class Api:
return {} return {}
def skip(self):
shared.state.skip()
def get_config(self): def get_config(self):
options = {} options = {}
for key in shared.opts.data.keys(): for key in shared.opts.data.keys():
@ -277,14 +273,10 @@ class Api:
return options return options
def set_config(self, req: OptionsModel): def set_config(self, req: Dict[str, Any]):
# currently req has all options fields even if you send a dict like { "send_seed": false }, which means it will
# overwrite all options with default values.
raise RuntimeError('Setting options via API is not supported')
reqDict = vars(req) for o in req:
for o in reqDict: setattr(shared.opts, o, req[o])
setattr(shared.opts, o, reqDict[o])
shared.opts.save(shared.config_filename) shared.opts.save(shared.config_filename)
return return
@ -293,7 +285,7 @@ class Api:
return vars(shared.cmd_opts) return vars(shared.cmd_opts)
def get_samplers(self): def get_samplers(self):
return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers] return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
def get_upscalers(self): def get_upscalers(self):
upscalers = [] upscalers = []

View file

@ -176,9 +176,9 @@ class InterrogateResponse(BaseModel):
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.") caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
fields = {} fields = {}
for key, value in opts.data.items(): for key, metadata in opts.data_labels.items():
metadata = opts.data_labels.get(key) value = opts.data.get(key)
optType = opts.typemap.get(type(value), type(value)) optType = opts.typemap.get(type(metadata.default), type(value))
if (metadata is not None): if (metadata is not None):
fields.update({key: (Optional[optType], Field( fields.update({key: (Optional[optType], Field(

View file

@ -65,9 +65,12 @@ class Extension:
self.can_update = False self.can_update = False
self.status = "latest" self.status = "latest"
def pull(self): def fetch_and_reset_hard(self):
repo = git.Repo(self.path) repo = git.Repo(self.path)
repo.remotes.origin.pull() # Fix: `error: Your local changes to the following files would be overwritten by merge`,
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
repo.git.fetch('--all')
repo.git.reset('--hard', 'origin')
def list_extensions(): def list_extensions():

View file

@ -73,6 +73,7 @@ def integrate_settings_paste_fields(component_dict):
'sd_hypernetwork': 'Hypernet', 'sd_hypernetwork': 'Hypernet',
'sd_hypernetwork_strength': 'Hypernet strength', 'sd_hypernetwork_strength': 'Hypernet strength',
'CLIP_stop_at_last_layers': 'Clip skip', 'CLIP_stop_at_last_layers': 'Clip skip',
'inpainting_mask_weight': 'Conditional mask weight',
'sd_model_checkpoint': 'Model hash', 'sd_model_checkpoint': 'Model hash',
} }
settings_paste_fields = [ settings_paste_fields = [

View file

@ -12,7 +12,7 @@ import torch
import tqdm import tqdm
from einops import rearrange, repeat from einops import rearrange, repeat
from ldm.util import default from ldm.util import default
from modules import devices, processing, sd_models, shared from modules import devices, processing, sd_models, shared, sd_samplers
from modules.textual_inversion import textual_inversion from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum from torch import einsum
@ -535,7 +535,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
p.prompt = preview_prompt p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt p.negative_prompt = preview_negative_prompt
p.steps = preview_steps p.steps = preview_steps
p.sampler_index = preview_sampler_index p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
p.cfg_scale = preview_cfg_scale p.cfg_scale = preview_cfg_scale
p.seed = preview_seed p.seed = preview_seed
p.width = preview_width p.width = preview_width

View file

@ -303,7 +303,7 @@ class FilenameGenerator:
'width': lambda self: self.image.width, 'width': lambda self: self.image.width,
'height': lambda self: self.image.height, 'height': lambda self: self.image.height,
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False), 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False), 'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash), 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'), 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>] 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]

View file

@ -6,7 +6,7 @@ import traceback
import numpy as np import numpy as np
from PIL import Image, ImageOps, ImageChops from PIL import Image, ImageOps, ImageChops
from modules import devices from modules import devices, sd_samplers
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state from modules.shared import opts, state
import modules.shared as shared import modules.shared as shared
@ -99,7 +99,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
seed_resize_from_h=seed_resize_from_h, seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w, seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras, seed_enable_extras=seed_enable_extras,
sampler_index=sampler_index, sampler_index=sd_samplers.samplers_for_img2img[sampler_index].name,
batch_size=batch_size, batch_size=batch_size,
n_iter=n_iter, n_iter=n_iter,
steps=steps, steps=steps,

View file

@ -2,6 +2,7 @@ import json
import math import math
import os import os
import sys import sys
import warnings
import torch import torch
import numpy as np import numpy as np
@ -66,19 +67,15 @@ def apply_overlay(image, paste_loc, index, overlays):
return image return image
def get_correct_sampler(p):
if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
return sd_samplers.samplers
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
return sd_samplers.samplers_for_img2img
elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
return sd_samplers.samplers
class StableDiffusionProcessing(): class StableDiffusionProcessing():
""" """
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
""" """
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None): def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, sampler_index: int = None):
if sampler_index is not None:
warnings.warn("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name")
self.sd_model = sd_model self.sd_model = sd_model
self.outpath_samples: str = outpath_samples self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids self.outpath_grids: str = outpath_grids
@ -91,7 +88,7 @@ class StableDiffusionProcessing():
self.subseed_strength: float = subseed_strength self.subseed_strength: float = subseed_strength
self.seed_resize_from_h: int = seed_resize_from_h self.seed_resize_from_h: int = seed_resize_from_h
self.seed_resize_from_w: int = seed_resize_from_w self.seed_resize_from_w: int = seed_resize_from_w
self.sampler_index: int = sampler_index self.sampler_name: str = sampler_name
self.batch_size: int = batch_size self.batch_size: int = batch_size
self.n_iter: int = n_iter self.n_iter: int = n_iter
self.steps: int = steps self.steps: int = steps
@ -116,6 +113,7 @@ class StableDiffusionProcessing():
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
self.s_noise = s_noise or opts.s_noise self.s_noise = s_noise or opts.s_noise
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
self.is_using_inpainting_conditioning = False
if not seed_enable_extras: if not seed_enable_extras:
self.subseed = -1 self.subseed = -1
@ -126,6 +124,7 @@ class StableDiffusionProcessing():
self.scripts = None self.scripts = None
self.script_args = None self.script_args = None
self.all_prompts = None self.all_prompts = None
self.all_negative_prompts = None
self.all_seeds = None self.all_seeds = None
self.all_subseeds = None self.all_subseeds = None
@ -136,6 +135,8 @@ class StableDiffusionProcessing():
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
return x.new_zeros(x.shape[0], 5, 1, 1) return x.new_zeros(x.shape[0], 5, 1, 1)
self.is_using_inpainting_conditioning = True
height = height or self.height height = height or self.height
width = width or self.width width = width or self.width
@ -154,6 +155,8 @@ class StableDiffusionProcessing():
# Dummy zero conditioning if we're not using inpainting model. # Dummy zero conditioning if we're not using inpainting model.
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
self.is_using_inpainting_conditioning = True
# Handle the different mask inputs # Handle the different mask inputs
if image_mask is not None: if image_mask is not None:
if torch.is_tensor(image_mask): if torch.is_tensor(image_mask):
@ -200,7 +203,7 @@ class StableDiffusionProcessing():
class Processed: class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None): def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
self.images = images_list self.images = images_list
self.prompt = p.prompt self.prompt = p.prompt
self.negative_prompt = p.negative_prompt self.negative_prompt = p.negative_prompt
@ -210,8 +213,7 @@ class Processed:
self.info = info self.info = info
self.width = p.width self.width = p.width
self.height = p.height self.height = p.height
self.sampler_index = p.sampler_index self.sampler_name = p.sampler_name
self.sampler = sd_samplers.samplers[p.sampler_index].name
self.cfg_scale = p.cfg_scale self.cfg_scale = p.cfg_scale
self.steps = p.steps self.steps = p.steps
self.batch_size = p.batch_size self.batch_size = p.batch_size
@ -238,17 +240,20 @@ class Processed:
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0] self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1 self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
self.all_prompts = all_prompts or [self.prompt] self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
self.all_seeds = all_seeds or [self.seed] self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
self.all_subseeds = all_subseeds or [self.subseed] self.all_seeds = all_seeds or p.all_seeds or [self.seed]
self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
self.infotexts = infotexts or [info] self.infotexts = infotexts or [info]
def js(self): def js(self):
obj = { obj = {
"prompt": self.prompt, "prompt": self.all_prompts[0],
"all_prompts": self.all_prompts, "all_prompts": self.all_prompts,
"negative_prompt": self.negative_prompt, "negative_prompt": self.all_negative_prompts[0],
"all_negative_prompts": self.all_negative_prompts,
"seed": self.seed, "seed": self.seed,
"all_seeds": self.all_seeds, "all_seeds": self.all_seeds,
"subseed": self.subseed, "subseed": self.subseed,
@ -256,8 +261,7 @@ class Processed:
"subseed_strength": self.subseed_strength, "subseed_strength": self.subseed_strength,
"width": self.width, "width": self.width,
"height": self.height, "height": self.height,
"sampler_index": self.sampler_index, "sampler_name": self.sampler_name,
"sampler": self.sampler,
"cfg_scale": self.cfg_scale, "cfg_scale": self.cfg_scale,
"steps": self.steps, "steps": self.steps,
"batch_size": self.batch_size, "batch_size": self.batch_size,
@ -273,6 +277,7 @@ class Processed:
"styles": self.styles, "styles": self.styles,
"job_timestamp": self.job_timestamp, "job_timestamp": self.job_timestamp,
"clip_skip": self.clip_skip, "clip_skip": self.clip_skip,
"is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
} }
return json.dumps(obj) return json.dumps(obj)
@ -384,7 +389,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params = { generation_params = {
"Steps": p.steps, "Steps": p.steps,
"Sampler": get_correct_sampler(p)[p.sampler_index].name, "Sampler": p.sampler_name,
"CFG scale": p.cfg_scale, "CFG scale": p.cfg_scale,
"Seed": all_seeds[index], "Seed": all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
@ -399,6 +404,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None), "Denoising strength": getattr(p, 'denoising_strength', None),
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
"Clip skip": None if clip_skip <= 1 else clip_skip, "Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta, "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
@ -408,7 +414,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else "" negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[0] if p.all_negative_prompts[0] else ""
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
@ -437,10 +443,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
else: else:
assert p.prompt is not None assert p.prompt is not None
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
devices.torch_gc() devices.torch_gc()
seed = get_fixed_seed(p.seed) seed = get_fixed_seed(p.seed)
@ -451,12 +453,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
comments = {} comments = {}
shared.prompt_styles.apply_styles(p)
if type(p.prompt) == list: if type(p.prompt) == list:
p.all_prompts = p.prompt p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
else: else:
p.all_prompts = p.batch_size * p.n_iter * [p.prompt] p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]
if type(p.negative_prompt) == list:
p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
else:
p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
if type(seed) == list: if type(seed) == list:
p.all_seeds = seed p.all_seeds = seed
@ -471,6 +476,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
def infotext(iteration=0, position_in_batch=0): def infotext(iteration=0, position_in_batch=0):
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch) return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings() model_hijack.embedding_db.load_textual_inversion_embeddings()
@ -495,6 +504,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
break break
prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size] prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
@ -505,7 +515,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds) p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
with devices.autocast(): with devices.autocast():
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
if len(model_hijack.comments) > 0: if len(model_hijack.comments) > 0:
@ -591,7 +601,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc() devices.torch_gc()
res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts) res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
if p.scripts is not None: if p.scripts is not None:
p.scripts.postprocess(p, res) p.scripts.postprocess(p, res)
@ -645,7 +655,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
if not self.enable_hr: if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
@ -706,7 +716,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
shared.state.nextjob() shared.state.nextjob()
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
@ -730,7 +740,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.denoising_strength: float = denoising_strength self.denoising_strength: float = denoising_strength
self.init_latent = None self.init_latent = None
self.image_mask = mask self.image_mask = mask
#self.image_unblurred_mask = None
self.latent_mask = None self.latent_mask = None
self.mask_for_overlay = None self.mask_for_overlay = None
self.mask_blur = mask_blur self.mask_blur = mask_blur
@ -743,39 +752,39 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.image_conditioning = None self.image_conditioning = None
def init(self, all_prompts, all_seeds, all_subseeds): def init(self, all_prompts, all_seeds, all_subseeds):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model) self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
crop_region = None crop_region = None
if self.image_mask is not None: image_mask = self.image_mask
self.image_mask = self.image_mask.convert('L')
if image_mask is not None:
image_mask = image_mask.convert('L')
if self.inpainting_mask_invert: if self.inpainting_mask_invert:
self.image_mask = ImageOps.invert(self.image_mask) image_mask = ImageOps.invert(image_mask)
#self.image_unblurred_mask = self.image_mask
if self.mask_blur > 0: if self.mask_blur > 0:
self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)) image_mask = image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
if self.inpaint_full_res: if self.inpaint_full_res:
self.mask_for_overlay = self.image_mask self.mask_for_overlay = image_mask
mask = self.image_mask.convert('L') mask = image_mask.convert('L')
crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
x1, y1, x2, y2 = crop_region x1, y1, x2, y2 = crop_region
mask = mask.crop(crop_region) mask = mask.crop(crop_region)
self.image_mask = images.resize_image(2, mask, self.width, self.height) image_mask = images.resize_image(2, mask, self.width, self.height)
self.paste_to = (x1, y1, x2-x1, y2-y1) self.paste_to = (x1, y1, x2-x1, y2-y1)
else: else:
self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height) image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
np_mask = np.array(self.image_mask) np_mask = np.array(image_mask)
np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
self.mask_for_overlay = Image.fromarray(np_mask) self.mask_for_overlay = Image.fromarray(np_mask)
self.overlay_images = [] self.overlay_images = []
latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
add_color_corrections = opts.img2img_color_correction and self.color_corrections is None add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
if add_color_corrections: if add_color_corrections:
@ -787,7 +796,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if crop_region is None: if crop_region is None:
image = images.resize_image(self.resize_mode, image, self.width, self.height) image = images.resize_image(self.resize_mode, image, self.width, self.height)
if self.image_mask is not None: if image_mask is not None:
image_masked = Image.new('RGBa', (image.width, image.height)) image_masked = Image.new('RGBa', (image.width, image.height))
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
@ -797,7 +806,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image = image.crop(crop_region) image = image.crop(crop_region)
image = images.resize_image(2, image, self.width, self.height) image = images.resize_image(2, image, self.width, self.height)
if self.image_mask is not None: if image_mask is not None:
if self.inpainting_fill != 1: if self.inpainting_fill != 1:
image = masking.fill(image, latent_mask) image = masking.fill(image, latent_mask)
@ -829,7 +838,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
if self.image_mask is not None: if image_mask is not None:
init_mask = latent_mask init_mask = latent_mask
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255 latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
@ -846,7 +855,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3: elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask self.init_latent = self.init_latent * self.mask
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask) self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

View file

@ -96,8 +96,8 @@ class StableDiffusionModelHijack:
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes: if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
self.apply_circular(False)
self.layers = None self.layers = None
self.circular_enabled = False
self.clip = None self.clip = None
def apply_circular(self, enable): def apply_circular(self, enable):

View file

@ -165,16 +165,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
cache_enabled = shared.opts.sd_checkpoint_cache > 0 cache_enabled = shared.opts.sd_checkpoint_cache > 0
if cache_enabled:
sd_vae.restore_base_vae(model)
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
if cache_enabled and checkpoint_info in checkpoints_loaded: if cache_enabled and checkpoint_info in checkpoints_loaded:
# use checkpoint cache # use checkpoint cache
vae_name = sd_vae.get_filename(vae_file) if vae_file else None print(f"Loading weights [{sd_model_hash}] from cache")
vae_message = f" with {vae_name} VAE" if vae_name else ""
print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
model.load_state_dict(checkpoints_loaded[checkpoint_info]) model.load_state_dict(checkpoints_loaded[checkpoint_info])
else: else:
# load from file # load from file
@ -220,6 +213,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
model.sd_model_checkpoint = checkpoint_file model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info model.sd_checkpoint_info = checkpoint_info
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
sd_vae.load_vae(model, vae_file) sd_vae.load_vae(model, vae_file)

View file

@ -46,13 +46,20 @@ all_samplers = [
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
] ]
all_samplers_map = {x.name: x for x in all_samplers}
samplers = [] samplers = []
samplers_for_img2img = [] samplers_for_img2img = []
def create_sampler_with_index(list_of_configs, index, model): def create_sampler(name, model):
config = list_of_configs[index] if name is not None:
config = all_samplers_map.get(name, None)
else:
config = all_samplers[0]
assert config is not None, f'bad sampler name: {name}'
sampler = config.constructor(model) sampler = config.constructor(model)
sampler.config = config sampler.config = config

View file

@ -83,47 +83,54 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path):
return vae_list return vae_list
def resolve_vae(checkpoint_file, vae_file="auto"): def get_vae_from_settings(vae_file="auto"):
global first_load, vae_dict, vae_list # else, we load from settings, if not set to be default
# if vae_file argument is provided, it takes priority, but not saved
if vae_file and vae_file not in default_vae_list:
if not os.path.isfile(vae_file):
vae_file = "auto"
print("VAE provided as function argument doesn't exist")
# for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
if first_load and shared.cmd_opts.vae_path is not None:
if os.path.isfile(shared.cmd_opts.vae_path):
vae_file = shared.cmd_opts.vae_path
shared.opts.data['sd_vae'] = get_filename(vae_file)
else:
print("VAE provided as command line argument doesn't exist")
# else, we load from settings
if vae_file == "auto" and shared.opts.sd_vae is not None: if vae_file == "auto" and shared.opts.sd_vae is not None:
# if saved VAE settings isn't recognized, fallback to auto # if saved VAE settings isn't recognized, fallback to auto
vae_file = vae_dict.get(shared.opts.sd_vae, "auto") vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
# if VAE selected but not found, fallback to auto # if VAE selected but not found, fallback to auto
if vae_file not in default_vae_values and not os.path.isfile(vae_file): if vae_file not in default_vae_values and not os.path.isfile(vae_file):
vae_file = "auto" vae_file = "auto"
print("Selected VAE doesn't exist") print(f"Selected VAE doesn't exist: {vae_file}")
return vae_file
def resolve_vae(checkpoint_file=None, vae_file="auto"):
global first_load, vae_dict, vae_list
# if vae_file argument is provided, it takes priority, but not saved
if vae_file and vae_file not in default_vae_list:
if not os.path.isfile(vae_file):
print(f"VAE provided as function argument doesn't exist: {vae_file}")
vae_file = "auto"
# for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
if first_load and shared.cmd_opts.vae_path is not None:
if os.path.isfile(shared.cmd_opts.vae_path):
vae_file = shared.cmd_opts.vae_path
shared.opts.data['sd_vae'] = get_filename(vae_file)
else:
print(f"VAE provided as command line argument doesn't exist: {vae_file}")
# fallback to selector in settings, if vae selector not set to act as default fallback
if not shared.opts.sd_vae_as_default:
vae_file = get_vae_from_settings(vae_file)
# vae-path cmd arg takes priority for auto # vae-path cmd arg takes priority for auto
if vae_file == "auto" and shared.cmd_opts.vae_path is not None: if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
if os.path.isfile(shared.cmd_opts.vae_path): if os.path.isfile(shared.cmd_opts.vae_path):
vae_file = shared.cmd_opts.vae_path vae_file = shared.cmd_opts.vae_path
print("Using VAE provided as command line argument") print(f"Using VAE provided as command line argument: {vae_file}")
# if still not found, try look for ".vae.pt" beside model # if still not found, try look for ".vae.pt" beside model
model_path = os.path.splitext(checkpoint_file)[0] model_path = os.path.splitext(checkpoint_file)[0]
if vae_file == "auto": if vae_file == "auto":
vae_file_try = model_path + ".vae.pt" vae_file_try = model_path + ".vae.pt"
if os.path.isfile(vae_file_try): if os.path.isfile(vae_file_try):
vae_file = vae_file_try vae_file = vae_file_try
print("Using VAE found beside selected model") print(f"Using VAE found similar to selected model: {vae_file}")
# if still not found, try look for ".vae.ckpt" beside model # if still not found, try look for ".vae.ckpt" beside model
if vae_file == "auto": if vae_file == "auto":
vae_file_try = model_path + ".vae.ckpt" vae_file_try = model_path + ".vae.ckpt"
if os.path.isfile(vae_file_try): if os.path.isfile(vae_file_try):
vae_file = vae_file_try vae_file = vae_file_try
print("Using VAE found beside selected model") print(f"Using VAE found similar to selected model: {vae_file}")
# No more fallbacks for auto # No more fallbacks for auto
if vae_file == "auto": if vae_file == "auto":
vae_file = None vae_file = None
@ -139,6 +146,7 @@ def load_vae(model, vae_file=None):
# save_settings = False # save_settings = False
if vae_file: if vae_file:
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
print(f"Loading VAE weights from: {vae_file}") print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}

View file

@ -335,7 +335,8 @@ options_templates.update(options_section(('training', "Training"), {
options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list), "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),

View file

@ -65,17 +65,6 @@ class StyleDatabase:
def apply_negative_styles_to_prompt(self, prompt, styles): def apply_negative_styles_to_prompt(self, prompt, styles):
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]) return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
def apply_styles(self, p: StableDiffusionProcessing) -> None:
if isinstance(p.prompt, list):
p.prompt = [self.apply_styles_to_prompt(prompt, p.styles) for prompt in p.prompt]
else:
p.prompt = self.apply_styles_to_prompt(p.prompt, p.styles)
if isinstance(p.negative_prompt, list):
p.negative_prompt = [self.apply_negative_styles_to_prompt(prompt, p.styles) for prompt in p.negative_prompt]
else:
p.negative_prompt = self.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
def save_styles(self, path: str) -> None: def save_styles(self, path: str) -> None:
# Write to temporary file first, so we don't nuke the file if something goes wrong # Write to temporary file first, so we don't nuke the file if something goes wrong
fd, temp_path = tempfile.mkstemp(".csv") fd, temp_path = tempfile.mkstemp(".csv")

View file

@ -10,7 +10,7 @@ import csv
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from modules import shared, devices, sd_hijack, processing, sd_models, images from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.learn_schedule import LearnRateScheduler
@ -345,7 +345,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
p.prompt = preview_prompt p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt p.negative_prompt = preview_negative_prompt
p.steps = preview_steps p.steps = preview_steps
p.sampler_index = preview_sampler_index p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
p.cfg_scale = preview_cfg_scale p.cfg_scale = preview_cfg_scale
p.seed = preview_seed p.seed = preview_seed
p.width = preview_width p.width = preview_width

View file

@ -18,7 +18,7 @@ def create_embedding(name, initialization_text, nvpt, overwrite_old):
def preprocess(*args): def preprocess(*args):
modules.textual_inversion.preprocess.preprocess(*args) modules.textual_inversion.preprocess.preprocess(*args)
return "Preprocessing finished.", "" return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", ""
def train_embedding(*args): def train_embedding(*args):

View file

@ -1,4 +1,5 @@
import modules.scripts import modules.scripts
from modules import sd_samplers
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \ from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
StableDiffusionProcessingImg2Img, process_images StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts
@ -21,7 +22,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
seed_resize_from_h=seed_resize_from_h, seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w, seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras, seed_enable_extras=seed_enable_extras,
sampler_index=sampler_index, sampler_name=sd_samplers.samplers[sampler_index].name,
batch_size=batch_size, batch_size=batch_size,
n_iter=n_iter, n_iter=n_iter,
steps=steps, steps=steps,

View file

@ -69,8 +69,11 @@ sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
css_hide_progressbar = """ css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; } .wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." } .wrap .m-12::before { content:"Loading..." }
.wrap .z-20 svg { display:none!important; }
.wrap .z-20::before { content:"Loading..." }
.progress-bar { display:none!important; } .progress-bar { display:none!important; }
.meta-text { display:none!important; } .meta-text { display:none!important; }
.meta-text-center { display:none!important; }
""" """
# Using constants for these since the variation selector isn't visible. # Using constants for these since the variation selector isn't visible.
@ -142,7 +145,7 @@ def save_files(js_data, images, do_make_zip, index):
filenames.append(os.path.basename(txt_fullfn)) filenames.append(os.path.basename(txt_fullfn))
fullfns.append(txt_fullfn) fullfns.append(txt_fullfn)
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
# Make Zip # Make Zip
if do_make_zip: if do_make_zip:
@ -1249,6 +1252,8 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="") gr.HTML(value="")
with gr.Column(): with gr.Column():
with gr.Row():
interrupt_preprocessing = gr.Button("Interrupt")
run_preprocess = gr.Button(value="Preprocess", variant='primary') run_preprocess = gr.Button(value="Preprocess", variant='primary')
process_split.change( process_split.change(
@ -1422,6 +1427,12 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[], outputs=[],
) )
interrupt_preprocessing.click(
fn=lambda: shared.state.interrupt(),
inputs=[],
outputs=[],
)
def create_setting_component(key, is_quicksettings=False): def create_setting_component(key, is_quicksettings=False):
def fun(): def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default return opts.data[key] if key in opts.data else opts.data_labels[key].default

View file

@ -36,9 +36,9 @@ def apply_and_restart(disable_list, update_list):
continue continue
try: try:
ext.pull() ext.fetch_and_reset_hard()
except Exception: except Exception:
print(f"Error pulling updates for {ext.name}:", file=sys.stderr) print(f"Error getting updates for {ext.name}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
shared.opts.disabled_extensions = disabled shared.opts.disabled_extensions = disabled

View file

@ -1,3 +1,4 @@
accelerate
basicsr basicsr
diffusers diffusers
fairscale==0.4.4 fairscale==0.4.4

View file

@ -1,5 +1,6 @@
transformers==4.19.2 transformers==4.19.2
diffusers==0.3.0 diffusers==0.3.0
accelerate==0.12.0
basicsr==1.4.2 basicsr==1.4.2
gfpgan==1.3.8 gfpgan==1.3.8
gradio==3.9 gradio==3.9

View file

@ -157,7 +157,7 @@ class Script(scripts.Script):
def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment): def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):
# Override # Override
if override_sampler: if override_sampler:
p.sampler_index = [sampler.name for sampler in sd_samplers.samplers].index("Euler") p.sampler_name = "Euler"
if override_prompt: if override_prompt:
p.prompt = original_prompt p.prompt = original_prompt
p.negative_prompt = original_negative_prompt p.negative_prompt = original_negative_prompt
@ -191,7 +191,7 @@ class Script(scripts.Script):
combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5) combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, p.sampler_index, p.sd_model) sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
sigmas = sampler.model_wrap.get_sigmas(p.steps) sigmas = sampler.model_wrap.get_sigmas(p.steps)

View file

@ -10,9 +10,9 @@ import numpy as np
import modules.scripts as scripts import modules.scripts as scripts
import gradio as gr import gradio as gr
from modules import images from modules import images, sd_samplers
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from modules.processing import process_images, Processed, get_correct_sampler, StableDiffusionProcessingTxt2Img from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
import modules.sd_samplers import modules.sd_samplers
@ -60,9 +60,9 @@ def apply_order(p, x, xs):
p.prompt = prompt_tmp + p.prompt p.prompt = prompt_tmp + p.prompt
def build_samplers_dict(p): def build_samplers_dict():
samplers_dict = {} samplers_dict = {}
for i, sampler in enumerate(get_correct_sampler(p)): for i, sampler in enumerate(sd_samplers.all_samplers):
samplers_dict[sampler.name.lower()] = i samplers_dict[sampler.name.lower()] = i
for alias in sampler.aliases: for alias in sampler.aliases:
samplers_dict[alias.lower()] = i samplers_dict[alias.lower()] = i
@ -70,7 +70,7 @@ def build_samplers_dict(p):
def apply_sampler(p, x, xs): def apply_sampler(p, x, xs):
sampler_index = build_samplers_dict(p).get(x.lower(), None) sampler_index = build_samplers_dict().get(x.lower(), None)
if sampler_index is None: if sampler_index is None:
raise RuntimeError(f"Unknown sampler: {x}") raise RuntimeError(f"Unknown sampler: {x}")
@ -78,7 +78,7 @@ def apply_sampler(p, x, xs):
def confirm_samplers(p, xs): def confirm_samplers(p, xs):
samplers_dict = build_samplers_dict(p) samplers_dict = build_samplers_dict()
for x in xs: for x in xs:
if x.lower() not in samplers_dict.keys(): if x.lower() not in samplers_dict.keys():
raise RuntimeError(f"Unknown sampler: {x}") raise RuntimeError(f"Unknown sampler: {x}")

View file

@ -4,5 +4,6 @@ set PYTHON=
set GIT= set GIT=
set VENV_DIR= set VENV_DIR=
set COMMANDLINE_ARGS= set COMMANDLINE_ARGS=
set ACCELERATE=
call webui.bat call webui.bat

View file

@ -40,4 +40,7 @@ export COMMANDLINE_ARGS=""
#export CODEFORMER_COMMIT_HASH="" #export CODEFORMER_COMMIT_HASH=""
#export BLIP_COMMIT_HASH="" #export BLIP_COMMIT_HASH=""
# Uncomment to enable accelerated launch
#export ACCELERATE="True"
########################################### ###########################################

View file

@ -28,15 +28,27 @@ goto :show_stdout_stderr
:activate_venv :activate_venv
set PYTHON="%~dp0%VENV_DIR%\Scripts\Python.exe" set PYTHON="%~dp0%VENV_DIR%\Scripts\Python.exe"
echo venv %PYTHON% echo venv %PYTHON%
if [%ACCELERATE%] == ["True"] goto :accelerate
goto :launch goto :launch
:skip_venv :skip_venv
:accelerate
echo "Checking for accelerate"
set ACCELERATE="%~dp0%VENV_DIR%\Scripts\accelerate.exe"
if EXIST %ACCELERATE% goto :accelerate_launch
:launch :launch
%PYTHON% launch.py %* %PYTHON% launch.py %*
pause pause
exit /b exit /b
:accelerate_launch
echo "Accelerating"
%ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py
pause
exit /b
:show_stdout_stderr :show_stdout_stderr
echo. echo.

View file

@ -82,6 +82,7 @@ def initialize():
modules.sd_models.load_model() modules.sd_models.load_model()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)

View file

@ -134,7 +134,15 @@ else
exit 1 exit 1
fi fi
if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ]
then
printf "\n%s\n" "${delimiter}"
printf "Accelerating launch.py..."
printf "\n%s\n" "${delimiter}"
accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@"
else
printf "\n%s\n" "${delimiter}" printf "\n%s\n" "${delimiter}"
printf "Launching launch.py..." printf "Launching launch.py..."
printf "\n%s\n" "${delimiter}" printf "\n%s\n" "${delimiter}"
"${python_cmd}" "${LAUNCH_SCRIPT}" "$@" "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
fi