option to let users select which samplers they want to hide
This commit is contained in:
parent
6e7057b31b
commit
5f24b7bcf4
4 changed files with 35 additions and 16 deletions
|
@ -11,9 +11,8 @@ import cv2
|
||||||
from skimage import exposure
|
from skimage import exposure
|
||||||
|
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
from modules import devices, prompt_parser, masking
|
from modules import devices, prompt_parser, masking, sd_samplers
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.sd_samplers import samplers, samplers_for_img2img
|
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import modules.face_restoration
|
import modules.face_restoration
|
||||||
|
@ -110,7 +109,7 @@ class Processed:
|
||||||
self.width = p.width
|
self.width = p.width
|
||||||
self.height = p.height
|
self.height = p.height
|
||||||
self.sampler_index = p.sampler_index
|
self.sampler_index = p.sampler_index
|
||||||
self.sampler = samplers[p.sampler_index].name
|
self.sampler = sd_samplers.samplers[p.sampler_index].name
|
||||||
self.cfg_scale = p.cfg_scale
|
self.cfg_scale = p.cfg_scale
|
||||||
self.steps = p.steps
|
self.steps = p.steps
|
||||||
self.batch_size = p.batch_size
|
self.batch_size = p.batch_size
|
||||||
|
@ -265,7 +264,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
||||||
|
|
||||||
generation_params = {
|
generation_params = {
|
||||||
"Steps": p.steps,
|
"Steps": p.steps,
|
||||||
"Sampler": samplers[p.sampler_index].name,
|
"Sampler": sd_samplers.samplers[p.sampler_index].name,
|
||||||
"CFG scale": p.cfg_scale,
|
"CFG scale": p.cfg_scale,
|
||||||
"Seed": all_seeds[index],
|
"Seed": all_seeds[index],
|
||||||
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
||||||
|
@ -478,7 +477,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
self.firstphase_height_truncated = int(scale * self.height)
|
self.firstphase_height_truncated = int(scale * self.height)
|
||||||
|
|
||||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
||||||
self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
|
self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model)
|
||||||
|
|
||||||
if not self.enable_hr:
|
if not self.enable_hr:
|
||||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||||
|
@ -521,7 +520,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
|
|
||||||
shared.state.nextjob()
|
shared.state.nextjob()
|
||||||
|
|
||||||
self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
|
self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model)
|
||||||
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||||
|
|
||||||
# GC now before running the next img2img to prevent running out of memory
|
# GC now before running the next img2img to prevent running out of memory
|
||||||
|
@ -556,7 +555,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||||
self.nmask = None
|
self.nmask = None
|
||||||
|
|
||||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||||
self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
|
self.sampler = sd_samplers.samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
|
||||||
crop_region = None
|
crop_region = None
|
||||||
|
|
||||||
if self.image_mask is not None:
|
if self.image_mask is not None:
|
||||||
|
|
|
@ -32,12 +32,27 @@ samplers_data_k_diffusion = [
|
||||||
if hasattr(k_diffusion.sampling, funcname)
|
if hasattr(k_diffusion.sampling, funcname)
|
||||||
]
|
]
|
||||||
|
|
||||||
samplers = [
|
all_samplers = [
|
||||||
*samplers_data_k_diffusion,
|
*samplers_data_k_diffusion,
|
||||||
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []),
|
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []),
|
||||||
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []),
|
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []),
|
||||||
]
|
]
|
||||||
samplers_for_img2img = [x for x in samplers if x.name not in ['PLMS', 'DPM fast', 'DPM adaptive']]
|
|
||||||
|
samplers = []
|
||||||
|
samplers_for_img2img = []
|
||||||
|
|
||||||
|
|
||||||
|
def set_samplers():
|
||||||
|
global samplers, samplers_for_img2img
|
||||||
|
|
||||||
|
hidden = set(opts.hide_samplers)
|
||||||
|
hidden_img2img = set(opts.hide_samplers + ['PLMS', 'DPM fast', 'DPM adaptive'])
|
||||||
|
|
||||||
|
samplers = [x for x in all_samplers if x.name not in hidden]
|
||||||
|
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
|
||||||
|
|
||||||
|
|
||||||
|
set_samplers()
|
||||||
|
|
||||||
sampler_extra_params = {
|
sampler_extra_params = {
|
||||||
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||||
|
|
|
@ -13,6 +13,7 @@ import modules.memmon
|
||||||
import modules.sd_models
|
import modules.sd_models
|
||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.devices as devices
|
import modules.devices as devices
|
||||||
|
from modules import sd_samplers
|
||||||
from modules.paths import script_path, sd_path
|
from modules.paths import script_path, sd_path
|
||||||
|
|
||||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||||
|
@ -238,6 +239,7 @@ options_templates.update(options_section(('ui', "User interface"), {
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||||
|
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}),
|
||||||
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
||||||
|
@ -246,6 +248,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
||||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
class Options:
|
class Options:
|
||||||
data = None
|
data = None
|
||||||
data_labels = options_templates
|
data_labels = options_templates
|
||||||
|
|
4
webui.py
4
webui.py
|
@ -2,7 +2,7 @@ import os
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import importlib
|
import importlib
|
||||||
from modules import devices
|
from modules import devices, sd_samplers
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
import signal
|
import signal
|
||||||
import threading
|
import threading
|
||||||
|
@ -109,6 +109,8 @@ def webui():
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
sd_samplers.set_samplers()
|
||||||
|
|
||||||
print('Reloading Custom Scripts')
|
print('Reloading Custom Scripts')
|
||||||
modules.scripts.reload_scripts(os.path.join(script_path, "scripts"))
|
modules.scripts.reload_scripts(os.path.join(script_path, "scripts"))
|
||||||
print('Reloading modules: modules.ui')
|
print('Reloading modules: modules.ui')
|
||||||
|
|
Loading…
Reference in a new issue