add a button for XY Plot to fill in available values for axes that support this
This commit is contained in:
parent
d073637e10
commit
55947857f0
3 changed files with 68 additions and 46 deletions
|
@ -20,6 +20,7 @@ titles = {
|
||||||
"\u{1f4be}": "Save style",
|
"\u{1f4be}": "Save style",
|
||||||
"\U0001F5D1": "Clear prompt",
|
"\U0001F5D1": "Clear prompt",
|
||||||
"\u{1f4cb}": "Apply selected styles to current prompt",
|
"\u{1f4cb}": "Apply selected styles to current prompt",
|
||||||
|
"\u{1f4d2}": "Paste available values into the field",
|
||||||
|
|
||||||
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
|
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
|
||||||
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
|
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
|
||||||
|
|
|
@ -10,7 +10,7 @@ import numpy as np
|
||||||
import modules.scripts as scripts
|
import modules.scripts as scripts
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import images, paths, sd_samplers, processing
|
from modules import images, paths, sd_samplers, processing, sd_models, sd_vae
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
|
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
|
@ -22,8 +22,9 @@ import glob
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from modules.ui_components import ToolButton
|
||||||
|
|
||||||
up_down_arrow_symbol = "\u2195\ufe0f"
|
fill_values_symbol = "\U0001f4d2" # 📒
|
||||||
|
|
||||||
|
|
||||||
def apply_field(field):
|
def apply_field(field):
|
||||||
|
@ -178,34 +179,49 @@ def str_permutations(x):
|
||||||
"""dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
|
"""dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
|
||||||
return x
|
return x
|
||||||
|
|
||||||
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm", "cost"])
|
|
||||||
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm", "cost"])
|
class AxisOption:
|
||||||
|
def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None):
|
||||||
|
self.label = label
|
||||||
|
self.type = type
|
||||||
|
self.apply = apply
|
||||||
|
self.format_value = format_value
|
||||||
|
self.confirm = confirm
|
||||||
|
self.cost = cost
|
||||||
|
self.choices = choices
|
||||||
|
self.is_img2img = False
|
||||||
|
|
||||||
|
|
||||||
|
class AxisOptionImg2Img(AxisOption):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.is_img2img = False
|
||||||
|
|
||||||
|
|
||||||
axis_options = [
|
axis_options = [
|
||||||
AxisOption("Nothing", str, do_nothing, format_nothing, None, 0),
|
AxisOption("Nothing", str, do_nothing, format_value=format_nothing),
|
||||||
AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None, 0),
|
AxisOption("Seed", int, apply_field("seed")),
|
||||||
AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None, 0),
|
AxisOption("Var. seed", int, apply_field("subseed")),
|
||||||
AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None, 0),
|
AxisOption("Var. strength", float, apply_field("subseed_strength")),
|
||||||
AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None, 0),
|
AxisOption("Steps", int, apply_field("steps")),
|
||||||
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None, 0),
|
AxisOption("CFG Scale", float, apply_field("cfg_scale")),
|
||||||
AxisOption("Prompt S/R", str, apply_prompt, format_value, None, 0),
|
AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
|
||||||
AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None, 0),
|
AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
|
||||||
AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers, 0),
|
AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
|
||||||
AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints, 1.0),
|
AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)),
|
||||||
AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks, 0.2),
|
AxisOption("Hypernetwork", str, apply_hypernetwork, format_value=format_value, confirm=confirm_hypernetworks, cost=0.2, choices=lambda: list(shared.hypernetworks)),
|
||||||
AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None, 0),
|
AxisOption("Hypernet str.", float, apply_hypernetwork_strength),
|
||||||
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None, 0),
|
AxisOption("Sigma Churn", float, apply_field("s_churn")),
|
||||||
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None, 0),
|
AxisOption("Sigma min", float, apply_field("s_tmin")),
|
||||||
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None, 0),
|
AxisOption("Sigma max", float, apply_field("s_tmax")),
|
||||||
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None, 0),
|
AxisOption("Sigma noise", float, apply_field("s_noise")),
|
||||||
AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None, 0),
|
AxisOption("Eta", float, apply_field("eta")),
|
||||||
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None, 0),
|
AxisOption("Clip skip", int, apply_clip_skip),
|
||||||
AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None, 0),
|
AxisOption("Denoising", float, apply_field("denoising_strength")),
|
||||||
AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None, 0),
|
AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [x.name for x in shared.sd_upscalers]),
|
||||||
AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None, 0),
|
AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
|
||||||
AxisOption("VAE", str, apply_vae, format_value_add_label, None, 0.7),
|
AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)),
|
||||||
AxisOption("Styles", str, apply_styles, format_value_add_label, None, 0),
|
AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -262,7 +278,7 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_
|
||||||
|
|
||||||
if not processed_result:
|
if not processed_result:
|
||||||
print("Unexpected error: draw_xy_grid failed to return even a single processed image")
|
print("Unexpected error: draw_xy_grid failed to return even a single processed image")
|
||||||
return Processed()
|
return Processed(p, [])
|
||||||
|
|
||||||
grid = images.image_grid(image_cache, rows=len(ys))
|
grid = images.image_grid(image_cache, rows=len(ys))
|
||||||
if draw_legend:
|
if draw_legend:
|
||||||
|
@ -302,23 +318,25 @@ class Script(scripts.Script):
|
||||||
return "X/Y plot"
|
return "X/Y plot"
|
||||||
|
|
||||||
def ui(self, is_img2img):
|
def ui(self, is_img2img):
|
||||||
current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img]
|
current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img and is_img2img]
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=1, elem_id="xy_grid_button_column"):
|
|
||||||
swap_axes_button = gr.Button(value=up_down_arrow_symbol, elem_id="xy_grid_swap_axes")
|
|
||||||
with gr.Column(scale=19):
|
with gr.Column(scale=19):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
|
x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
|
||||||
x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
|
x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
|
||||||
|
fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_x_tool_button", visible=False)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
|
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
|
||||||
y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
|
y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
|
||||||
|
fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False)
|
||||||
draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
|
|
||||||
include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images"))
|
with gr.Row(variant="compact"):
|
||||||
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
|
draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
|
||||||
|
include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images"))
|
||||||
|
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
|
||||||
|
swap_axes_button = gr.Button(value="Swap axes", elem_id="xy_grid_swap_axes_button")
|
||||||
|
|
||||||
def swap_axes(x_type, x_values, y_type, y_values):
|
def swap_axes(x_type, x_values, y_type, y_values):
|
||||||
nonlocal current_axis_options
|
nonlocal current_axis_options
|
||||||
|
@ -327,6 +345,19 @@ class Script(scripts.Script):
|
||||||
swap_args = [x_type, x_values, y_type, y_values]
|
swap_args = [x_type, x_values, y_type, y_values]
|
||||||
swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args)
|
swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args)
|
||||||
|
|
||||||
|
def fill(x_type):
|
||||||
|
axis = axis_options[x_type]
|
||||||
|
return ", ".join(axis.choices()) if axis.choices else gr.update()
|
||||||
|
|
||||||
|
fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values])
|
||||||
|
fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values])
|
||||||
|
|
||||||
|
def select_axis(x_type):
|
||||||
|
return gr.Button.update(visible=axis_options[x_type].choices is not None)
|
||||||
|
|
||||||
|
x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button])
|
||||||
|
y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])
|
||||||
|
|
||||||
return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds]
|
return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds]
|
||||||
|
|
||||||
def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds):
|
def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds):
|
||||||
|
|
12
style.css
12
style.css
|
@ -644,7 +644,7 @@ canvas[key="mask"] {
|
||||||
max-width: 2.5em;
|
max-width: 2.5em;
|
||||||
min-width: 2.5em !important;
|
min-width: 2.5em !important;
|
||||||
height: 2.4em;
|
height: 2.4em;
|
||||||
margin: 0.55em 0;
|
margin: 0.55em 0.7em 0.55em 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
#quicksettings .gr-button-tool{
|
#quicksettings .gr-button-tool{
|
||||||
|
@ -717,16 +717,6 @@ footer {
|
||||||
line-height: 2.4em;
|
line-height: 2.4em;
|
||||||
}
|
}
|
||||||
|
|
||||||
#xy_grid_button_column {
|
|
||||||
min-width: 38px !important;
|
|
||||||
}
|
|
||||||
|
|
||||||
#xy_grid_button_column button {
|
|
||||||
height: 100%;
|
|
||||||
margin-bottom: 0.7em;
|
|
||||||
margin-left: 1em;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* The following handles localization for right-to-left (RTL) languages like Arabic.
|
/* The following handles localization for right-to-left (RTL) languages like Arabic.
|
||||||
The rtl media type will only be activated by the logic in javascript/localization.js.
|
The rtl media type will only be activated by the logic in javascript/localization.js.
|
||||||
If you change anything above, you need to make sure it is RTL compliant by just running
|
If you change anything above, you need to make sure it is RTL compliant by just running
|
||||||
|
|
Loading…
Reference in a new issue