Fix display and save order for X/Y/Z Grid script
This commit is contained in:
parent
0cc0ee1bcb
commit
2d9635cce5
1 changed files with 69 additions and 56 deletions
|
@ -25,8 +25,6 @@ from modules.ui_components import ToolButton
|
||||||
|
|
||||||
fill_values_symbol = "\U0001f4d2" # 📒
|
fill_values_symbol = "\U0001f4d2" # 📒
|
||||||
|
|
||||||
AxisInfo = namedtuple('AxisInfo', ['axis', 'values'])
|
|
||||||
|
|
||||||
|
|
||||||
def apply_field(field):
|
def apply_field(field):
|
||||||
def fun(p, x, xs):
|
def fun(p, x, xs):
|
||||||
|
@ -188,7 +186,6 @@ axis_options = [
|
||||||
AxisOption("Steps", int, apply_field("steps")),
|
AxisOption("Steps", int, apply_field("steps")),
|
||||||
AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")),
|
AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")),
|
||||||
AxisOption("CFG Scale", float, apply_field("cfg_scale")),
|
AxisOption("CFG Scale", float, apply_field("cfg_scale")),
|
||||||
AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),
|
|
||||||
AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
|
AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
|
||||||
AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
|
AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
|
||||||
AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
|
AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
|
||||||
|
@ -213,49 +210,47 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
|
||||||
ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
|
ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
|
||||||
title_texts = [[images.GridAnnotation(z)] for z in z_labels]
|
title_texts = [[images.GridAnnotation(z)] for z in z_labels]
|
||||||
|
|
||||||
# Temporary list of all the images that are generated to be populated into the grid.
|
list_size = (len(xs) * len(ys) * len(zs))
|
||||||
# Will be filled with empty images for any individual step that fails to process properly
|
|
||||||
image_cache = [None] * (len(xs) * len(ys) * len(zs))
|
|
||||||
|
|
||||||
processed_result = None
|
processed_result = None
|
||||||
cell_mode = "P"
|
|
||||||
cell_size = (1, 1)
|
|
||||||
|
|
||||||
state.job_count = len(xs) * len(ys) * len(zs) * p.n_iter
|
state.job_count = list_size * p.n_iter
|
||||||
|
|
||||||
def process_cell(x, y, z, ix, iy, iz):
|
def process_cell(x, y, z, ix, iy, iz):
|
||||||
nonlocal image_cache, processed_result, cell_mode, cell_size
|
nonlocal processed_result
|
||||||
|
|
||||||
def index(ix, iy, iz):
|
def index(ix, iy, iz):
|
||||||
return ix + iy * len(xs) + iz * len(xs) * len(ys)
|
return ix + iy * len(xs) + iz * len(xs) * len(ys)
|
||||||
|
|
||||||
state.job = f"{index(ix, iy, iz) + 1} out of {len(xs) * len(ys) * len(zs)}"
|
state.job = f"{index(ix, iy, iz) + 1} out of {list_size}"
|
||||||
|
|
||||||
processed: Processed = cell(x, y, z)
|
processed: Processed = cell(x, y, z)
|
||||||
|
|
||||||
try:
|
|
||||||
# this dereference will throw an exception if the image was not processed
|
|
||||||
# (this happens in cases such as if the user stops the process from the UI)
|
|
||||||
processed_image = processed.images[0]
|
|
||||||
|
|
||||||
if processed_result is None:
|
if processed_result is None:
|
||||||
# Use our first valid processed result as a template container to hold our full results
|
# Use our first processed result object as a template container to hold our full results
|
||||||
processed_result = copy(processed)
|
processed_result = copy(processed)
|
||||||
cell_mode = processed_image.mode
|
processed_result.images = [None] * list_size
|
||||||
cell_size = processed_image.size
|
processed_result.all_prompts = [None] * list_size
|
||||||
processed_result.images = [Image.new(cell_mode, cell_size)]
|
processed_result.all_seeds = [None] * list_size
|
||||||
processed_result.all_prompts = [processed.prompt]
|
processed_result.infotexts = [None] * list_size
|
||||||
processed_result.all_seeds = [processed.seed]
|
processed_result.index_of_first_image = 0
|
||||||
processed_result.infotexts = [processed.infotexts[0]]
|
|
||||||
|
idx = index(ix, iy, iz)
|
||||||
|
if processed.images:
|
||||||
|
# Non-empty list indicates some degree of success.
|
||||||
|
processed_result.images[idx] = processed.images[0]
|
||||||
|
processed_result.all_prompts[idx] = processed.prompt
|
||||||
|
processed_result.all_seeds[idx] = processed.seed
|
||||||
|
processed_result.infotexts[idx] = processed.infotexts[0]
|
||||||
|
else:
|
||||||
|
cell_mode = "P"
|
||||||
|
cell_size = (processed_result.width, processed_result.height)
|
||||||
|
if processed_result.images[0] is not None:
|
||||||
|
cell_mode = processed_result.images[0].mode
|
||||||
|
#This corrects size in case of batches:
|
||||||
|
cell_size = processed_result.images[0].size
|
||||||
|
processed_result.images[idx] = Image.new(cell_mode, cell_size)
|
||||||
|
|
||||||
image_cache[index(ix, iy, iz)] = processed_image
|
|
||||||
if include_lone_images:
|
|
||||||
processed_result.images.append(processed_image)
|
|
||||||
processed_result.all_prompts.append(processed.prompt)
|
|
||||||
processed_result.all_seeds.append(processed.seed)
|
|
||||||
processed_result.infotexts.append(processed.infotexts[0])
|
|
||||||
except:
|
|
||||||
image_cache[index(ix, iy, iz)] = Image.new(cell_mode, cell_size)
|
|
||||||
|
|
||||||
if first_axes_processed == 'x':
|
if first_axes_processed == 'x':
|
||||||
for ix, x in enumerate(xs):
|
for ix, x in enumerate(xs):
|
||||||
|
@ -289,27 +284,36 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
|
||||||
process_cell(x, y, z, ix, iy, iz)
|
process_cell(x, y, z, ix, iy, iz)
|
||||||
|
|
||||||
if not processed_result:
|
if not processed_result:
|
||||||
|
# Should never happen, I've only seen it on one of four open tabs and it needed to refresh.
|
||||||
|
print("Unexpected error: Processing could not begin, you may need to refresh the tab or restart the service.")
|
||||||
|
return Processed(p, [])
|
||||||
|
elif not any(processed_result.images):
|
||||||
print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
|
print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
|
||||||
return Processed(p, [])
|
return Processed(p, [])
|
||||||
|
|
||||||
sub_grids = [None] * len(zs)
|
z_count = len(zs)
|
||||||
for i in range(len(zs)):
|
sub_grids = [None] * z_count
|
||||||
start_index = i * len(xs) * len(ys)
|
for i in range(z_count):
|
||||||
|
start_index = (i * len(xs) * len(ys)) + i
|
||||||
end_index = start_index + len(xs) * len(ys)
|
end_index = start_index + len(xs) * len(ys)
|
||||||
grid = images.image_grid(image_cache[start_index:end_index], rows=len(ys))
|
grid = images.image_grid(processed_result.images[start_index:end_index], rows=len(ys))
|
||||||
if draw_legend:
|
if draw_legend:
|
||||||
grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts, margin_size)
|
grid = images.draw_grid_annotations(grid, processed_result.images[start_index].size[0], processed_result.images[start_index].size[1], hor_texts, ver_texts, margin_size)
|
||||||
sub_grids[i] = grid
|
processed_result.images.insert(i, grid)
|
||||||
if include_sub_grids and len(zs) > 1:
|
processed_result.all_prompts.insert(i, processed_result.all_prompts[start_index])
|
||||||
processed_result.images.insert(i+1, grid)
|
processed_result.all_seeds.insert(i, processed_result.all_seeds[start_index])
|
||||||
|
processed_result.infotexts.insert(i, processed_result.infotexts[start_index])
|
||||||
|
|
||||||
sub_grid_size = sub_grids[0].size
|
sub_grid_size = processed_result.images[0].size
|
||||||
z_grid = images.image_grid(sub_grids, rows=1)
|
z_grid = images.image_grid(processed_result.images[:z_count], rows=1)
|
||||||
if draw_legend:
|
if draw_legend:
|
||||||
z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
|
z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
|
||||||
processed_result.images[0] = z_grid
|
processed_result.images.insert(0, z_grid)
|
||||||
|
processed_result.all_prompts.insert(0, processed_result.all_prompts[0])
|
||||||
|
processed_result.all_seeds.insert(0, processed_result.all_seeds[0])
|
||||||
|
processed_result.infotexts.insert(0, processed_result.infotexts[0])
|
||||||
|
|
||||||
return processed_result, sub_grids
|
return processed_result
|
||||||
|
|
||||||
|
|
||||||
class SharedSettingsStackHelper(object):
|
class SharedSettingsStackHelper(object):
|
||||||
|
@ -364,7 +368,7 @@ class Script(scripts.Script):
|
||||||
include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
|
include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
|
||||||
include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
|
include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
|
margin_size = gr.Slider(label="Grid margins (px)", min=0, max=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
|
||||||
|
|
||||||
with gr.Row(variant="compact", elem_id="swap_axes"):
|
with gr.Row(variant="compact", elem_id="swap_axes"):
|
||||||
swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
|
swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
|
||||||
|
@ -526,14 +530,10 @@ class Script(scripts.Script):
|
||||||
|
|
||||||
grid_infotext = [None]
|
grid_infotext = [None]
|
||||||
|
|
||||||
state.xyz_plot_x = AxisInfo(x_opt, xs)
|
|
||||||
state.xyz_plot_y = AxisInfo(y_opt, ys)
|
|
||||||
state.xyz_plot_z = AxisInfo(z_opt, zs)
|
|
||||||
|
|
||||||
# If one of the axes is very slow to change between (like SD model
|
# If one of the axes is very slow to change between (like SD model
|
||||||
# checkpoint), then make sure it is in the outer iteration of the nested
|
# checkpoint), then make sure it is in the outer iteration of the nested
|
||||||
# `for` loop.
|
# `for` loop.
|
||||||
first_axes_processed = 'x'
|
first_axes_processed = 'z'
|
||||||
second_axes_processed = 'y'
|
second_axes_processed = 'y'
|
||||||
if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost:
|
if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost:
|
||||||
first_axes_processed = 'x'
|
first_axes_processed = 'x'
|
||||||
|
@ -593,7 +593,7 @@ class Script(scripts.Script):
|
||||||
return res
|
return res
|
||||||
|
|
||||||
with SharedSettingsStackHelper():
|
with SharedSettingsStackHelper():
|
||||||
processed, sub_grids = draw_xyz_grid(
|
processed = draw_xyz_grid(
|
||||||
p,
|
p,
|
||||||
xs=xs,
|
xs=xs,
|
||||||
ys=ys,
|
ys=ys,
|
||||||
|
@ -610,11 +610,24 @@ class Script(scripts.Script):
|
||||||
margin_size=margin_size
|
margin_size=margin_size
|
||||||
)
|
)
|
||||||
|
|
||||||
if opts.grid_save and len(sub_grids) > 1:
|
z_count = len(zs)
|
||||||
for sub_grid in sub_grids:
|
|
||||||
images.save_image(sub_grid, p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
|
|
||||||
|
|
||||||
if opts.grid_save:
|
if not include_lone_images:
|
||||||
images.save_image(processed.images[0], p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
|
# Don't need sub-images anymore, drop from list:
|
||||||
|
processed.images = processed.images[:z_count+1]
|
||||||
|
|
||||||
|
if opts.grid_save and processed.images:
|
||||||
|
# Auto-save main and sub-grids:
|
||||||
|
grid_count = z_count + 1 if z_count > 1 else 1
|
||||||
|
for g in range(grid_count):
|
||||||
|
images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[g], seed=processed.all_seeds[g], grid=True, p=processed)
|
||||||
|
|
||||||
|
if not include_sub_grids:
|
||||||
|
# Done with sub-grids, drop all related information:
|
||||||
|
for sg in range(z_count):
|
||||||
|
del processed.images[1]
|
||||||
|
del processed.all_prompts[1]
|
||||||
|
del processed.all_seeds[1]
|
||||||
|
del processed.infotexts[1]
|
||||||
|
|
||||||
return processed
|
return processed
|
||||||
|
|
Loading…
Reference in a new issue