Add a prompt order option to XY plot script

This commit is contained in:
DoTheSneedful 2022-10-03 22:20:09 -04:00 committed by AUTOMATIC1111
parent 5ef0baf5ea
commit 1c5604791d

View file

@ -1,5 +1,6 @@
from collections import namedtuple
from copy import copy
from itertools import permutations
import random
from PIL import Image
@ -28,6 +29,27 @@ def apply_prompt(p, x, xs):
p.prompt = p.prompt.replace(xs[0], x)
p.negative_prompt = p.negative_prompt.replace(xs[0], x)
def apply_order(p, x, xs):
token_order = []
# Initally grab the tokens from the prompt so they can be later be replaced in order of earliest seen in the prompt
for token in x:
token_order.append((p.prompt.find(token), token))
token_order.sort(key=lambda t: t[0])
search_from_pos = 0
for idx, token in enumerate(x):
original_pos, old_token = token_order[idx]
# Get position of the token again as it will likely change as tokens are being replaced
pos = p.prompt.find(old_token)
if original_pos >= 0:
# Avoid trying to replace what was just replaced by searching later in the prompt string
p.prompt = p.prompt[0:search_from_pos] + p.prompt[search_from_pos:].replace(old_token, token, 1)
search_from_pos = pos + len(token)
samplers_dict = {}
for i, sampler in enumerate(modules.sd_samplers.samplers):
@ -60,7 +82,8 @@ def format_value_add_label(p, opt, x):
def format_value(p, opt, x):
if type(x) == float:
x = round(x, 8)
if type(x) == type(list()):
x = str(x)
return x
def do_nothing(p, x, xs):
@ -89,6 +112,7 @@ axis_options = [
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
AxisOption("Eta", float, apply_field("eta"), format_value_add_label),
AxisOption("Prompt order", type(list()), apply_order, format_value),
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
]
@ -159,7 +183,11 @@ class Script(scripts.Script):
if opt.label == 'Nothing':
return [0]
valslist = [x.strip() for x in vals.split(",")]
if opt.type == type(list()):
valslist = [x for x in vals]
else:
valslist = [x.strip() for x in vals.split(",")]
if opt.type == int:
valslist_ext = []
@ -212,9 +240,17 @@ class Script(scripts.Script):
return valslist
x_opt = axis_options[x_type]
if x_opt.label == "Prompt order":
x_values = list(permutations([x.strip() for x in x_values.split(",")]))
xs = process_axis(x_opt, x_values)
y_opt = axis_options[y_type]
if y_opt.label == "Prompt order":
y_values = list(permutations([y.strip() for y in y_values.split(",")]))
ys = process_axis(y_opt, y_values)
def fix_axis_seeds(axis_opt, axis_list):