add an option to enable sections from extras tab in txt2img/img2img

fix some style inconsistenices
This commit is contained in:
AUTOMATIC 2023-01-26 23:29:27 +03:00
parent 645f4e7ef8
commit 7a14c8ab45
9 changed files with 133 additions and 23 deletions

View file

@ -13,7 +13,7 @@ from skimage import exposure
from typing import Any, Dict, List, Optional
import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@ -658,6 +658,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image = Image.fromarray(x_sample)
if p.scripts is not None:
pp = scripts.PostprocessImageArgs(image)
p.scripts.postprocess_image(p, pp)
image = pp.image
if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)

View file

@ -6,12 +6,16 @@ from collections import namedtuple
import gradio as gr
from modules.processing import StableDiffusionProcessing
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
AlwaysVisible = object()
class PostprocessImageArgs:
def __init__(self, image):
self.image = image
class Script:
filename = None
args_from = None
@ -65,7 +69,7 @@ class Script:
args contains all values returned by components from ui()
"""
raise NotImplementedError()
pass
def process(self, p, *args):
"""
@ -100,6 +104,13 @@ class Script:
pass
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
"""
Called for every image after it has been generated.
"""
pass
def postprocess(self, p, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
@ -247,11 +258,15 @@ class ScriptRunner:
self.infotext_fields = []
def initialize_scripts(self, is_img2img):
from modules import scripts_auto_postprocessing
self.scripts.clear()
self.alwayson_scripts.clear()
self.selectable_scripts.clear()
for script_class, path, basedir, script_module in scripts_data:
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
script = script_class()
script.filename = path
script.is_txt2img = not is_img2img
@ -332,7 +347,7 @@ class ScriptRunner:
return inputs
def run(self, p: StableDiffusionProcessing, *args):
def run(self, p, *args):
script_index = args[0]
if script_index == 0:
@ -386,6 +401,15 @@ class ScriptRunner:
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_image(p, pp, *script_args)
except Exception:
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def before_component(self, component, **kwargs):
for script in self.scripts:
try:

View file

@ -0,0 +1,42 @@
from modules import scripts, scripts_postprocessing, shared
class ScriptPostprocessingForMainUI(scripts.Script):
def __init__(self, script_postproc):
self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
self.postprocessing_controls = None
def title(self):
return self.script.name
def show(self, is_img2img):
return scripts.AlwaysVisible
def ui(self, is_img2img):
self.postprocessing_controls = self.script.ui()
return self.postprocessing_controls.values()
def postprocess_image(self, p, script_pp, *args):
args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
pp.info = {}
self.script.process(pp, **args_dict)
p.extra_generation_params.update(pp.info)
script_pp.image = pp.image
def create_auto_preprocessing_script_data():
from modules import scripts
res = []
for name in shared.opts.postprocessing_enable_in_main_ui:
script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
if script is None:
continue
constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class())
res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))
return res

View file

@ -46,6 +46,8 @@ class ScriptPostprocessing:
pass
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try:
res = func(*args, **kwargs)
@ -68,6 +70,9 @@ class ScriptPostprocessingRunner:
script: ScriptPostprocessing = script_class()
script.filename = path
if script.name == "Simple Upscale":
continue
self.scripts.append(script)
def create_script_ui(self, script, inputs):
@ -87,12 +92,11 @@ class ScriptPostprocessingRunner:
import modules.scripts
self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")]
scripts_order = shared.opts.postprocessing_operation_order
def script_score(name):
name = name.lower()
for i, possible_match in enumerate(scripts_order):
if possible_match in name:
if possible_match == name:
return i
return len(self.scripts)
@ -145,3 +149,4 @@ class ScriptPostprocessingRunner:
def image_changed(self):
for script in self.scripts_in_preferred_order():
script.image_changed()

View file

@ -13,8 +13,8 @@ import modules.interrogate
import modules.memmon
import modules.styles
import modules.devices as devices
from modules import localization, sd_vae, extensions, script_loading, errors, ui_components
from modules.paths import models_path, script_path, sd_path
from modules import localization, sd_vae, extensions, script_loading, errors, ui_components, shared_items
from modules.paths import models_path, script_path
demo = None
@ -264,12 +264,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = []
def realesrgan_models_names():
import modules.realesrgan_model
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
class OptionInfo:
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
self.default = default
@ -360,7 +354,7 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
}))
@ -483,7 +477,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
}))
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
'postprocessing_scipts_order': OptionInfo("upscale, gfpgan, codeformer", "Postprocessing operation order"),
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
}))

10
modules/shared_items.py Normal file
View file

@ -0,0 +1,10 @@
def realesrgan_models_names():
import modules.realesrgan_model
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
def postprocessing_scripts():
import modules.scripts
return modules.scripts.scripts_postproc.scripts

View file

@ -48,3 +48,11 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
def get_block_name(self):
return "colorpicker"
class DropdownMulti(gr.Dropdown):
"""Same as gr.Dropdown but always multiselect"""
def __init__(self, **kwargs):
super().__init__(multiselect=True, **kwargs)
def get_block_name(self):
return "dropdown"

View file

@ -104,3 +104,28 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
def image_changed(self):
upscale_cache.clear()
class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale):
name = "Simple Upscale"
order = 900
def ui(self):
with FormRow():
upscaler_name = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
upscale_by = gr.Slider(minimum=0.05, maximum=8.0, step=0.05, label="Upscale by", value=2)
return {
"upscale_by": upscale_by,
"upscaler_name": upscaler_name,
}
def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):
if upscaler_name is None or upscaler_name == "None":
return
upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_name]), None)
assert upscaler1, f'could not find upscaler named {upscaler_name}'
pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False)
pp.info[f"Postprocess upscaler"] = upscaler1.name

View file

@ -164,7 +164,7 @@
min-height: 3.2em;
}
#txt2img_styles ul, #img2img_styles ul{
ul.list-none{
max-height: 35em;
z-index: 2000;
}
@ -714,9 +714,6 @@ footer {
white-space: nowrap;
min-width: auto;
}
#txt2img_hires_fix{
margin-left: -0.8em;
}
#img2img_copy_to_img2img, #img2img_copy_to_sketch, #img2img_copy_to_inpaint, #img2img_copy_to_inpaint_sketch{
margin-left: 0em;
@ -744,7 +741,6 @@ footer {
.dark .gr-compact{
background-color: rgb(31 41 55 / var(--tw-bg-opacity));
margin-left: 0.8em;
}
.gr-compact{