Merge branch 'a1111' into vae-fix-none

This commit is contained in:
Muhammad Rizqi Nur 2022-11-19 16:38:21 +07:00
commit 8662b5e57f
15 changed files with 66 additions and 72 deletions

View file

@ -6,9 +6,9 @@ from threading import Lock
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, FastAPI, HTTPException from fastapi import APIRouter, Depends, FastAPI, HTTPException
import modules.shared as shared import modules.shared as shared
from modules import sd_samplers
from modules.api.models import * from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo from modules.extras import run_extras, run_pnginfo
from PIL import PngImagePlugin from PIL import PngImagePlugin
from modules.sd_models import checkpoints_list from modules.sd_models import checkpoints_list
@ -25,8 +25,12 @@ def upscaler_to_index(name: str):
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}") raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) def validate_sampler_name(name):
config = sd_samplers.all_samplers_map.get(name, None)
if config is None:
raise HTTPException(status_code=404, detail="Sampler not found")
return name
def setUpscalers(req: dict): def setUpscalers(req: dict):
reqDict = vars(req) reqDict = vars(req)
@ -82,14 +86,9 @@ class Api:
self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found")
populate = txt2imgreq.copy(update={ # Override __init__ params populate = txt2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model, "sd_model": shared.sd_model,
"sampler_index": sampler_index[0], "sampler_name": validate_sampler_name(txt2imgreq.sampler_index),
"do_not_save_samples": True, "do_not_save_samples": True,
"do_not_save_grid": True "do_not_save_grid": True
} }
@ -109,12 +108,6 @@ class Api:
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index)
if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found")
init_images = img2imgreq.init_images init_images = img2imgreq.init_images
if init_images is None: if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found") raise HTTPException(status_code=404, detail="Init image not found")
@ -123,10 +116,9 @@ class Api:
if mask: if mask:
mask = decode_base64_to_image(mask) mask = decode_base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model, "sd_model": shared.sd_model,
"sampler_index": sampler_index[0], "sampler_name": validate_sampler_name(img2imgreq.sampler_index),
"do_not_save_samples": True, "do_not_save_samples": True,
"do_not_save_grid": True, "do_not_save_grid": True,
"mask": mask "mask": mask
@ -272,7 +264,7 @@ class Api:
return vars(shared.cmd_opts) return vars(shared.cmd_opts)
def get_samplers(self): def get_samplers(self):
return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers] return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
def get_upscalers(self): def get_upscalers(self):
upscalers = [] upscalers = []

View file

@ -65,9 +65,12 @@ class Extension:
self.can_update = False self.can_update = False
self.status = "latest" self.status = "latest"
def pull(self): def fetch_and_reset_hard(self):
repo = git.Repo(self.path) repo = git.Repo(self.path)
repo.remotes.origin.pull() # Fix: `error: Your local changes to the following files would be overwritten by merge`,
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
repo.git.fetch('--all')
repo.git.reset('--hard', 'origin')
def list_extensions(): def list_extensions():

View file

@ -12,7 +12,7 @@ import torch
import tqdm import tqdm
from einops import rearrange, repeat from einops import rearrange, repeat
from ldm.util import default from ldm.util import default
from modules import devices, processing, sd_models, shared from modules import devices, processing, sd_models, shared, sd_samplers
from modules.textual_inversion import textual_inversion from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum from torch import einsum
@ -535,7 +535,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
p.prompt = preview_prompt p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt p.negative_prompt = preview_negative_prompt
p.steps = preview_steps p.steps = preview_steps
p.sampler_index = preview_sampler_index p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
p.cfg_scale = preview_cfg_scale p.cfg_scale = preview_cfg_scale
p.seed = preview_seed p.seed = preview_seed
p.width = preview_width p.width = preview_width

View file

@ -303,7 +303,7 @@ class FilenameGenerator:
'width': lambda self: self.image.width, 'width': lambda self: self.image.width,
'height': lambda self: self.image.height, 'height': lambda self: self.image.height,
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False), 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False), 'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash), 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'), 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>] 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]

View file

@ -6,7 +6,7 @@ import traceback
import numpy as np import numpy as np
from PIL import Image, ImageOps, ImageChops from PIL import Image, ImageOps, ImageChops
from modules import devices from modules import devices, sd_samplers
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state from modules.shared import opts, state
import modules.shared as shared import modules.shared as shared
@ -99,7 +99,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
seed_resize_from_h=seed_resize_from_h, seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w, seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras, seed_enable_extras=seed_enable_extras,
sampler_index=sampler_index, sampler_index=sd_samplers.samplers_for_img2img[sampler_index].name,
batch_size=batch_size, batch_size=batch_size,
n_iter=n_iter, n_iter=n_iter,
steps=steps, steps=steps,

View file

@ -2,6 +2,7 @@ import json
import math import math
import os import os
import sys import sys
import warnings
import torch import torch
import numpy as np import numpy as np
@ -66,19 +67,15 @@ def apply_overlay(image, paste_loc, index, overlays):
return image return image
def get_correct_sampler(p):
if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
return sd_samplers.samplers
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
return sd_samplers.samplers_for_img2img
elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
return sd_samplers.samplers
class StableDiffusionProcessing(): class StableDiffusionProcessing():
""" """
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
""" """
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None): def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, sampler_index: int = None):
if sampler_index is not None:
warnings.warn("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name")
self.sd_model = sd_model self.sd_model = sd_model
self.outpath_samples: str = outpath_samples self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids self.outpath_grids: str = outpath_grids
@ -91,7 +88,7 @@ class StableDiffusionProcessing():
self.subseed_strength: float = subseed_strength self.subseed_strength: float = subseed_strength
self.seed_resize_from_h: int = seed_resize_from_h self.seed_resize_from_h: int = seed_resize_from_h
self.seed_resize_from_w: int = seed_resize_from_w self.seed_resize_from_w: int = seed_resize_from_w
self.sampler_index: int = sampler_index self.sampler_name: str = sampler_name
self.batch_size: int = batch_size self.batch_size: int = batch_size
self.n_iter: int = n_iter self.n_iter: int = n_iter
self.steps: int = steps self.steps: int = steps
@ -210,8 +207,7 @@ class Processed:
self.info = info self.info = info
self.width = p.width self.width = p.width
self.height = p.height self.height = p.height
self.sampler_index = p.sampler_index self.sampler_name = p.sampler_name
self.sampler = sd_samplers.samplers[p.sampler_index].name
self.cfg_scale = p.cfg_scale self.cfg_scale = p.cfg_scale
self.steps = p.steps self.steps = p.steps
self.batch_size = p.batch_size self.batch_size = p.batch_size
@ -256,8 +252,7 @@ class Processed:
"subseed_strength": self.subseed_strength, "subseed_strength": self.subseed_strength,
"width": self.width, "width": self.width,
"height": self.height, "height": self.height,
"sampler_index": self.sampler_index, "sampler_name": self.sampler_name,
"sampler": self.sampler,
"cfg_scale": self.cfg_scale, "cfg_scale": self.cfg_scale,
"steps": self.steps, "steps": self.steps,
"batch_size": self.batch_size, "batch_size": self.batch_size,
@ -384,7 +379,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params = { generation_params = {
"Steps": p.steps, "Steps": p.steps,
"Sampler": get_correct_sampler(p)[p.sampler_index].name, "Sampler": p.sampler_name,
"CFG scale": p.cfg_scale, "CFG scale": p.cfg_scale,
"Seed": all_seeds[index], "Seed": all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
@ -399,6 +394,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None), "Denoising strength": getattr(p, 'denoising_strength', None),
"Inpainting strength": (None if getattr(p, 'denoising_strength', None) is None else getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)),
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
"Clip skip": None if clip_skip <= 1 else clip_skip, "Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta, "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
@ -645,7 +641,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
if not self.enable_hr: if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
@ -706,7 +702,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
shared.state.nextjob() shared.state.nextjob()
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
@ -743,7 +739,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.image_conditioning = None self.image_conditioning = None
def init(self, all_prompts, all_seeds, all_subseeds): def init(self, all_prompts, all_seeds, all_subseeds):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model) self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
crop_region = None crop_region = None
if self.image_mask is not None: if self.image_mask is not None:

View file

@ -165,16 +165,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
cache_enabled = shared.opts.sd_checkpoint_cache > 0 cache_enabled = shared.opts.sd_checkpoint_cache > 0
if cache_enabled:
sd_vae.restore_base_vae(model)
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
if cache_enabled and checkpoint_info in checkpoints_loaded: if cache_enabled and checkpoint_info in checkpoints_loaded:
# use checkpoint cache # use checkpoint cache
vae_name = sd_vae.get_filename(vae_file) if vae_file else None print(f"Loading weights [{sd_model_hash}] from cache")
vae_message = f" with {vae_name} VAE" if vae_name else ""
print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
model.load_state_dict(checkpoints_loaded[checkpoint_info]) model.load_state_dict(checkpoints_loaded[checkpoint_info])
else: else:
# load from file # load from file
@ -222,6 +215,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
sd_vae.delete_base_vae() sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae() sd_vae.clear_loaded_vae()
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
sd_vae.load_vae(model, vae_file) sd_vae.load_vae(model, vae_file)

View file

@ -46,13 +46,20 @@ all_samplers = [
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
] ]
all_samplers_map = {x.name: x for x in all_samplers}
samplers = [] samplers = []
samplers_for_img2img = [] samplers_for_img2img = []
def create_sampler_with_index(list_of_configs, index, model): def create_sampler(name, model):
config = list_of_configs[index] if name is not None:
config = all_samplers_map.get(name, None)
else:
config = all_samplers[0]
assert config is not None, f'bad sampler name: {name}'
sampler = config.constructor(model) sampler = config.constructor(model)
sampler.config = config sampler.config = config

View file

@ -95,7 +95,7 @@ def get_vae_from_settings(vae_file="auto"):
# if VAE selected but not found, fallback to auto # if VAE selected but not found, fallback to auto
if vae_file not in default_vae_values and not os.path.isfile(vae_file): if vae_file not in default_vae_values and not os.path.isfile(vae_file):
vae_file = "auto" vae_file = "auto"
print("Selected VAE doesn't exist") print(f"Selected VAE doesn't exist: {vae_file}")
return vae_file return vae_file
@ -105,15 +105,15 @@ def resolve_vae(checkpoint_file=None, vae_file="auto"):
# if vae_file argument is provided, it takes priority, but not saved # if vae_file argument is provided, it takes priority, but not saved
if vae_file and vae_file not in default_vae_list: if vae_file and vae_file not in default_vae_list:
if not os.path.isfile(vae_file): if not os.path.isfile(vae_file):
print(f"VAE provided as function argument doesn't exist: {vae_file}")
vae_file = "auto" vae_file = "auto"
print("VAE provided as function argument doesn't exist")
# for the first load, if vae-path is provided, it takes priority, saved, and failure is reported # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
if first_load and shared.cmd_opts.vae_path is not None: if first_load and shared.cmd_opts.vae_path is not None:
if os.path.isfile(shared.cmd_opts.vae_path): if os.path.isfile(shared.cmd_opts.vae_path):
vae_file = shared.cmd_opts.vae_path vae_file = shared.cmd_opts.vae_path
shared.opts.data['sd_vae'] = get_filename(vae_file) shared.opts.data['sd_vae'] = get_filename(vae_file)
else: else:
print("VAE provided as command line argument doesn't exist") print(f"VAE provided as command line argument doesn't exist: {vae_file}")
# fallback to selector in settings, if vae selector not set to act as default fallback # fallback to selector in settings, if vae selector not set to act as default fallback
if not shared.opts.sd_vae_as_default: if not shared.opts.sd_vae_as_default:
vae_file = get_vae_from_settings(vae_file) vae_file = get_vae_from_settings(vae_file)
@ -121,20 +121,20 @@ def resolve_vae(checkpoint_file=None, vae_file="auto"):
if vae_file == "auto" and shared.cmd_opts.vae_path is not None: if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
if os.path.isfile(shared.cmd_opts.vae_path): if os.path.isfile(shared.cmd_opts.vae_path):
vae_file = shared.cmd_opts.vae_path vae_file = shared.cmd_opts.vae_path
print("Using VAE provided as command line argument") print(f"Using VAE provided as command line argument: {vae_file}")
# if still not found, try look for ".vae.pt" beside model # if still not found, try look for ".vae.pt" beside model
model_path = os.path.splitext(checkpoint_file)[0] model_path = os.path.splitext(checkpoint_file)[0]
if vae_file == "auto": if vae_file == "auto":
vae_file_try = model_path + ".vae.pt" vae_file_try = model_path + ".vae.pt"
if os.path.isfile(vae_file_try): if os.path.isfile(vae_file_try):
vae_file = vae_file_try vae_file = vae_file_try
print("Using VAE found beside selected model") print(f"Using VAE found similar to selected model: {vae_file}")
# if still not found, try look for ".vae.ckpt" beside model # if still not found, try look for ".vae.ckpt" beside model
if vae_file == "auto": if vae_file == "auto":
vae_file_try = model_path + ".vae.ckpt" vae_file_try = model_path + ".vae.ckpt"
if os.path.isfile(vae_file_try): if os.path.isfile(vae_file_try):
vae_file = vae_file_try vae_file = vae_file_try
print("Using VAE found beside selected model") print(f"Using VAE found similar to selected model: {vae_file}")
# No more fallbacks for auto # No more fallbacks for auto
if vae_file == "auto": if vae_file == "auto":
vae_file = None vae_file = None
@ -150,6 +150,7 @@ def load_vae(model, vae_file=None):
# save_settings = False # save_settings = False
if vae_file: if vae_file:
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
print(f"Loading VAE weights from: {vae_file}") print(f"Loading VAE weights from: {vae_file}")
store_base_vae(model) store_base_vae(model)
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)

View file

@ -10,7 +10,7 @@ import csv
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from modules import shared, devices, sd_hijack, processing, sd_models, images from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.learn_schedule import LearnRateScheduler
@ -345,7 +345,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
p.prompt = preview_prompt p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt p.negative_prompt = preview_negative_prompt
p.steps = preview_steps p.steps = preview_steps
p.sampler_index = preview_sampler_index p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
p.cfg_scale = preview_cfg_scale p.cfg_scale = preview_cfg_scale
p.seed = preview_seed p.seed = preview_seed
p.width = preview_width p.width = preview_width

View file

@ -1,4 +1,5 @@
import modules.scripts import modules.scripts
from modules import sd_samplers
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \ from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
StableDiffusionProcessingImg2Img, process_images StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts
@ -21,7 +22,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
seed_resize_from_h=seed_resize_from_h, seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w, seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras, seed_enable_extras=seed_enable_extras,
sampler_index=sampler_index, sampler_name=sd_samplers.samplers[sampler_index].name,
batch_size=batch_size, batch_size=batch_size,
n_iter=n_iter, n_iter=n_iter,
steps=steps, steps=steps,

View file

@ -142,7 +142,7 @@ def save_files(js_data, images, do_make_zip, index):
filenames.append(os.path.basename(txt_fullfn)) filenames.append(os.path.basename(txt_fullfn))
fullfns.append(txt_fullfn) fullfns.append(txt_fullfn)
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
# Make Zip # Make Zip
if do_make_zip: if do_make_zip:

View file

@ -36,9 +36,9 @@ def apply_and_restart(disable_list, update_list):
continue continue
try: try:
ext.pull() ext.fetch_and_reset_hard()
except Exception: except Exception:
print(f"Error pulling updates for {ext.name}:", file=sys.stderr) print(f"Error getting updates for {ext.name}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
shared.opts.disabled_extensions = disabled shared.opts.disabled_extensions = disabled

View file

@ -157,7 +157,7 @@ class Script(scripts.Script):
def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment): def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):
# Override # Override
if override_sampler: if override_sampler:
p.sampler_index = [sampler.name for sampler in sd_samplers.samplers].index("Euler") p.sampler_name = "Euler"
if override_prompt: if override_prompt:
p.prompt = original_prompt p.prompt = original_prompt
p.negative_prompt = original_negative_prompt p.negative_prompt = original_negative_prompt
@ -191,7 +191,7 @@ class Script(scripts.Script):
combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5) combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, p.sampler_index, p.sd_model) sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
sigmas = sampler.model_wrap.get_sigmas(p.steps) sigmas = sampler.model_wrap.get_sigmas(p.steps)

View file

@ -10,9 +10,9 @@ import numpy as np
import modules.scripts as scripts import modules.scripts as scripts
import gradio as gr import gradio as gr
from modules import images from modules import images, sd_samplers
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from modules.processing import process_images, Processed, get_correct_sampler, StableDiffusionProcessingTxt2Img from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
import modules.sd_samplers import modules.sd_samplers
@ -60,9 +60,9 @@ def apply_order(p, x, xs):
p.prompt = prompt_tmp + p.prompt p.prompt = prompt_tmp + p.prompt
def build_samplers_dict(p): def build_samplers_dict():
samplers_dict = {} samplers_dict = {}
for i, sampler in enumerate(get_correct_sampler(p)): for i, sampler in enumerate(sd_samplers.all_samplers):
samplers_dict[sampler.name.lower()] = i samplers_dict[sampler.name.lower()] = i
for alias in sampler.aliases: for alias in sampler.aliases:
samplers_dict[alias.lower()] = i samplers_dict[alias.lower()] = i
@ -70,7 +70,7 @@ def build_samplers_dict(p):
def apply_sampler(p, x, xs): def apply_sampler(p, x, xs):
sampler_index = build_samplers_dict(p).get(x.lower(), None) sampler_index = build_samplers_dict().get(x.lower(), None)
if sampler_index is None: if sampler_index is None:
raise RuntimeError(f"Unknown sampler: {x}") raise RuntimeError(f"Unknown sampler: {x}")
@ -78,7 +78,7 @@ def apply_sampler(p, x, xs):
def confirm_samplers(p, xs): def confirm_samplers(p, xs):
samplers_dict = build_samplers_dict(p) samplers_dict = build_samplers_dict()
for x in xs: for x in xs:
if x.lower() not in samplers_dict.keys(): if x.lower() not in samplers_dict.keys():
raise RuntimeError(f"Unknown sampler: {x}") raise RuntimeError(f"Unknown sampler: {x}")