option to let users select which samplers they want to hide

This commit is contained in:
AUTOMATIC 2022-10-06 12:08:48 +03:00
parent 6e7057b31b
commit 5f24b7bcf4
4 changed files with 35 additions and 16 deletions

View file

@ -11,9 +11,8 @@ import cv2
from skimage import exposure from skimage import exposure
import modules.sd_hijack import modules.sd_hijack
from modules import devices, prompt_parser, masking from modules import devices, prompt_parser, masking, sd_samplers
from modules.sd_hijack import model_hijack from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
import modules.face_restoration import modules.face_restoration
@ -110,7 +109,7 @@ class Processed:
self.width = p.width self.width = p.width
self.height = p.height self.height = p.height
self.sampler_index = p.sampler_index self.sampler_index = p.sampler_index
self.sampler = samplers[p.sampler_index].name self.sampler = sd_samplers.samplers[p.sampler_index].name
self.cfg_scale = p.cfg_scale self.cfg_scale = p.cfg_scale
self.steps = p.steps self.steps = p.steps
self.batch_size = p.batch_size self.batch_size = p.batch_size
@ -265,7 +264,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params = { generation_params = {
"Steps": p.steps, "Steps": p.steps,
"Sampler": samplers[p.sampler_index].name, "Sampler": sd_samplers.samplers[p.sampler_index].name,
"CFG scale": p.cfg_scale, "CFG scale": p.cfg_scale,
"Seed": all_seeds[index], "Seed": all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
@ -478,7 +477,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.firstphase_height_truncated = int(scale * self.height) self.firstphase_height_truncated = int(scale * self.height)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = samplers[self.sampler_index].constructor(self.sd_model) self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model)
if not self.enable_hr: if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
@ -521,7 +520,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
shared.state.nextjob() shared.state.nextjob()
self.sampler = samplers[self.sampler_index].constructor(self.sd_model) self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model)
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
# GC now before running the next img2img to prevent running out of memory # GC now before running the next img2img to prevent running out of memory
@ -556,7 +555,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.nmask = None self.nmask = None
def init(self, all_prompts, all_seeds, all_subseeds): def init(self, all_prompts, all_seeds, all_subseeds):
self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model) self.sampler = sd_samplers.samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
crop_region = None crop_region = None
if self.image_mask is not None: if self.image_mask is not None:

View file

@ -32,12 +32,27 @@ samplers_data_k_diffusion = [
if hasattr(k_diffusion.sampling, funcname) if hasattr(k_diffusion.sampling, funcname)
] ]
samplers = [ all_samplers = [
*samplers_data_k_diffusion, *samplers_data_k_diffusion,
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []), SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []),
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []), SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []),
] ]
samplers_for_img2img = [x for x in samplers if x.name not in ['PLMS', 'DPM fast', 'DPM adaptive']]
samplers = []
samplers_for_img2img = []
def set_samplers():
global samplers, samplers_for_img2img
hidden = set(opts.hide_samplers)
hidden_img2img = set(opts.hide_samplers + ['PLMS', 'DPM fast', 'DPM adaptive'])
samplers = [x for x in all_samplers if x.name not in hidden]
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
set_samplers()
sampler_extra_params = { sampler_extra_params = {
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],

View file

@ -13,6 +13,7 @@ import modules.memmon
import modules.sd_models import modules.sd_models
import modules.styles import modules.styles
import modules.devices as devices import modules.devices as devices
from modules import sd_samplers
from modules.paths import script_path, sd_path from modules.paths import script_path, sd_path
sd_model_file = os.path.join(script_path, 'model.ckpt') sd_model_file = os.path.join(script_path, 'model.ckpt')
@ -238,6 +239,7 @@ options_templates.update(options_section(('ui', "User interface"), {
})) }))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), { options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}),
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
@ -246,6 +248,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
})) }))
class Options: class Options:
data = None data = None
data_labels = options_templates data_labels = options_templates

View file

@ -2,7 +2,7 @@ import os
import threading import threading
import time import time
import importlib import importlib
from modules import devices from modules import devices, sd_samplers
from modules.paths import script_path from modules.paths import script_path
import signal import signal
import threading import threading
@ -109,6 +109,8 @@ def webui():
time.sleep(0.5) time.sleep(0.5)
break break
sd_samplers.set_samplers()
print('Reloading Custom Scripts') print('Reloading Custom Scripts')
modules.scripts.reload_scripts(os.path.join(script_path, "scripts")) modules.scripts.reload_scripts(os.path.join(script_path, "scripts"))
print('Reloading modules: modules.ui') print('Reloading modules: modules.ui')