option to let users select which samplers they want to hide
This commit is contained in:
parent
6e7057b31b
commit
5f24b7bcf4
4 changed files with 35 additions and 16 deletions
|
@ -11,9 +11,8 @@ import cv2
|
|||
from skimage import exposure
|
||||
|
||||
import modules.sd_hijack
|
||||
from modules import devices, prompt_parser, masking
|
||||
from modules import devices, prompt_parser, masking, sd_samplers
|
||||
from modules.sd_hijack import model_hijack
|
||||
from modules.sd_samplers import samplers, samplers_for_img2img
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
import modules.face_restoration
|
||||
|
@ -110,7 +109,7 @@ class Processed:
|
|||
self.width = p.width
|
||||
self.height = p.height
|
||||
self.sampler_index = p.sampler_index
|
||||
self.sampler = samplers[p.sampler_index].name
|
||||
self.sampler = sd_samplers.samplers[p.sampler_index].name
|
||||
self.cfg_scale = p.cfg_scale
|
||||
self.steps = p.steps
|
||||
self.batch_size = p.batch_size
|
||||
|
@ -265,7 +264,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
|||
|
||||
generation_params = {
|
||||
"Steps": p.steps,
|
||||
"Sampler": samplers[p.sampler_index].name,
|
||||
"Sampler": sd_samplers.samplers[p.sampler_index].name,
|
||||
"CFG scale": p.cfg_scale,
|
||||
"Seed": all_seeds[index],
|
||||
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
||||
|
@ -478,7 +477,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||
self.firstphase_height_truncated = int(scale * self.height)
|
||||
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
||||
self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
|
||||
self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model)
|
||||
|
||||
if not self.enable_hr:
|
||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
|
@ -521,7 +520,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||
|
||||
shared.state.nextjob()
|
||||
|
||||
self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
|
||||
self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model)
|
||||
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
|
||||
# GC now before running the next img2img to prevent running out of memory
|
||||
|
@ -556,7 +555,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||
self.nmask = None
|
||||
|
||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||
self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
|
||||
self.sampler = sd_samplers.samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
|
||||
crop_region = None
|
||||
|
||||
if self.image_mask is not None:
|
||||
|
|
|
@ -32,12 +32,27 @@ samplers_data_k_diffusion = [
|
|||
if hasattr(k_diffusion.sampling, funcname)
|
||||
]
|
||||
|
||||
samplers = [
|
||||
all_samplers = [
|
||||
*samplers_data_k_diffusion,
|
||||
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []),
|
||||
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []),
|
||||
]
|
||||
samplers_for_img2img = [x for x in samplers if x.name not in ['PLMS', 'DPM fast', 'DPM adaptive']]
|
||||
|
||||
samplers = []
|
||||
samplers_for_img2img = []
|
||||
|
||||
|
||||
def set_samplers():
|
||||
global samplers, samplers_for_img2img
|
||||
|
||||
hidden = set(opts.hide_samplers)
|
||||
hidden_img2img = set(opts.hide_samplers + ['PLMS', 'DPM fast', 'DPM adaptive'])
|
||||
|
||||
samplers = [x for x in all_samplers if x.name not in hidden]
|
||||
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
|
||||
|
||||
|
||||
set_samplers()
|
||||
|
||||
sampler_extra_params = {
|
||||
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
|
|
|
@ -13,6 +13,7 @@ import modules.memmon
|
|||
import modules.sd_models
|
||||
import modules.styles
|
||||
import modules.devices as devices
|
||||
from modules import sd_samplers
|
||||
from modules.paths import script_path, sd_path
|
||||
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
|
@ -238,6 +239,7 @@ options_templates.update(options_section(('ui', "User interface"), {
|
|||
}))
|
||||
|
||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}),
|
||||
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
||||
|
@ -246,6 +248,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
|||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
}))
|
||||
|
||||
|
||||
class Options:
|
||||
data = None
|
||||
data_labels = options_templates
|
||||
|
|
4
webui.py
4
webui.py
|
@ -2,7 +2,7 @@ import os
|
|||
import threading
|
||||
import time
|
||||
import importlib
|
||||
from modules import devices
|
||||
from modules import devices, sd_samplers
|
||||
from modules.paths import script_path
|
||||
import signal
|
||||
import threading
|
||||
|
@ -109,6 +109,8 @@ def webui():
|
|||
time.sleep(0.5)
|
||||
break
|
||||
|
||||
sd_samplers.set_samplers()
|
||||
|
||||
print('Reloading Custom Scripts')
|
||||
modules.scripts.reload_scripts(os.path.join(script_path, "scripts"))
|
||||
print('Reloading modules: modules.ui')
|
||||
|
|
Loading…
Reference in a new issue