add a button for XY Plot to fill in available values for axes that support this

This commit is contained in:
AUTOMATIC 2023-01-16 17:36:56 +03:00
parent d073637e10
commit 55947857f0
3 changed files with 68 additions and 46 deletions

View file

@ -20,6 +20,7 @@ titles = {
"\u{1f4be}": "Save style", "\u{1f4be}": "Save style",
"\U0001F5D1": "Clear prompt", "\U0001F5D1": "Clear prompt",
"\u{1f4cb}": "Apply selected styles to current prompt", "\u{1f4cb}": "Apply selected styles to current prompt",
"\u{1f4d2}": "Paste available values into the field",
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",

View file

@ -10,7 +10,7 @@ import numpy as np
import modules.scripts as scripts import modules.scripts as scripts
import gradio as gr import gradio as gr
from modules import images, paths, sd_samplers, processing from modules import images, paths, sd_samplers, processing, sd_models, sd_vae
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
@ -22,8 +22,9 @@ import glob
import os import os
import re import re
from modules.ui_components import ToolButton
up_down_arrow_symbol = "\u2195\ufe0f" fill_values_symbol = "\U0001f4d2" # 📒
def apply_field(field): def apply_field(field):
@ -178,34 +179,49 @@ def str_permutations(x):
"""dummy function for specifying it in AxisOption's type when you want to get a list of permutations""" """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
return x return x
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm", "cost"])
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm", "cost"]) class AxisOption:
def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None):
self.label = label
self.type = type
self.apply = apply
self.format_value = format_value
self.confirm = confirm
self.cost = cost
self.choices = choices
self.is_img2img = False
class AxisOptionImg2Img(AxisOption):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_img2img = False
axis_options = [ axis_options = [
AxisOption("Nothing", str, do_nothing, format_nothing, None, 0), AxisOption("Nothing", str, do_nothing, format_value=format_nothing),
AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None, 0), AxisOption("Seed", int, apply_field("seed")),
AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None, 0), AxisOption("Var. seed", int, apply_field("subseed")),
AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None, 0), AxisOption("Var. strength", float, apply_field("subseed_strength")),
AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None, 0), AxisOption("Steps", int, apply_field("steps")),
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None, 0), AxisOption("CFG Scale", float, apply_field("cfg_scale")),
AxisOption("Prompt S/R", str, apply_prompt, format_value, None, 0), AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None, 0), AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers, 0), AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints, 1.0), AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)),
AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks, 0.2), AxisOption("Hypernetwork", str, apply_hypernetwork, format_value=format_value, confirm=confirm_hypernetworks, cost=0.2, choices=lambda: list(shared.hypernetworks)),
AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None, 0), AxisOption("Hypernet str.", float, apply_hypernetwork_strength),
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None, 0), AxisOption("Sigma Churn", float, apply_field("s_churn")),
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None, 0), AxisOption("Sigma min", float, apply_field("s_tmin")),
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None, 0), AxisOption("Sigma max", float, apply_field("s_tmax")),
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None, 0), AxisOption("Sigma noise", float, apply_field("s_noise")),
AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None, 0), AxisOption("Eta", float, apply_field("eta")),
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None, 0), AxisOption("Clip skip", int, apply_clip_skip),
AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None, 0), AxisOption("Denoising", float, apply_field("denoising_strength")),
AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None, 0), AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [x.name for x in shared.sd_upscalers]),
AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None, 0), AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
AxisOption("VAE", str, apply_vae, format_value_add_label, None, 0.7), AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)),
AxisOption("Styles", str, apply_styles, format_value_add_label, None, 0), AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
] ]
@ -262,7 +278,7 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_
if not processed_result: if not processed_result:
print("Unexpected error: draw_xy_grid failed to return even a single processed image") print("Unexpected error: draw_xy_grid failed to return even a single processed image")
return Processed() return Processed(p, [])
grid = images.image_grid(image_cache, rows=len(ys)) grid = images.image_grid(image_cache, rows=len(ys))
if draw_legend: if draw_legend:
@ -302,23 +318,25 @@ class Script(scripts.Script):
return "X/Y plot" return "X/Y plot"
def ui(self, is_img2img): def ui(self, is_img2img):
current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img] current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img and is_img2img]
with gr.Row(): with gr.Row():
with gr.Column(scale=1, elem_id="xy_grid_button_column"):
swap_axes_button = gr.Button(value=up_down_arrow_symbol, elem_id="xy_grid_swap_axes")
with gr.Column(scale=19): with gr.Column(scale=19):
with gr.Row(): with gr.Row():
x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_x_tool_button", visible=False)
with gr.Row(): with gr.Row():
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False)
with gr.Row(variant="compact"):
draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images")) include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images"))
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
swap_axes_button = gr.Button(value="Swap axes", elem_id="xy_grid_swap_axes_button")
def swap_axes(x_type, x_values, y_type, y_values): def swap_axes(x_type, x_values, y_type, y_values):
nonlocal current_axis_options nonlocal current_axis_options
@ -327,6 +345,19 @@ class Script(scripts.Script):
swap_args = [x_type, x_values, y_type, y_values] swap_args = [x_type, x_values, y_type, y_values]
swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args) swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args)
def fill(x_type):
axis = axis_options[x_type]
return ", ".join(axis.choices()) if axis.choices else gr.update()
fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values])
fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values])
def select_axis(x_type):
return gr.Button.update(visible=axis_options[x_type].choices is not None)
x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button])
y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])
return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds] return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds]
def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds): def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds):

View file

@ -644,7 +644,7 @@ canvas[key="mask"] {
max-width: 2.5em; max-width: 2.5em;
min-width: 2.5em !important; min-width: 2.5em !important;
height: 2.4em; height: 2.4em;
margin: 0.55em 0; margin: 0.55em 0.7em 0.55em 0;
} }
#quicksettings .gr-button-tool{ #quicksettings .gr-button-tool{
@ -717,16 +717,6 @@ footer {
line-height: 2.4em; line-height: 2.4em;
} }
#xy_grid_button_column {
min-width: 38px !important;
}
#xy_grid_button_column button {
height: 100%;
margin-bottom: 0.7em;
margin-left: 1em;
}
/* The following handles localization for right-to-left (RTL) languages like Arabic. /* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js. The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running If you change anything above, you need to make sure it is RTL compliant by just running