Add a prompt order option to XY plot script
This commit is contained in:
parent
5ef0baf5ea
commit
1c5604791d
1 changed files with 38 additions and 2 deletions
|
@ -1,5 +1,6 @@
|
|||
from collections import namedtuple
|
||||
from copy import copy
|
||||
from itertools import permutations
|
||||
import random
|
||||
|
||||
from PIL import Image
|
||||
|
@ -28,6 +29,27 @@ def apply_prompt(p, x, xs):
|
|||
p.prompt = p.prompt.replace(xs[0], x)
|
||||
p.negative_prompt = p.negative_prompt.replace(xs[0], x)
|
||||
|
||||
def apply_order(p, x, xs):
|
||||
token_order = []
|
||||
|
||||
# Initally grab the tokens from the prompt so they can be later be replaced in order of earliest seen in the prompt
|
||||
for token in x:
|
||||
token_order.append((p.prompt.find(token), token))
|
||||
|
||||
token_order.sort(key=lambda t: t[0])
|
||||
|
||||
search_from_pos = 0
|
||||
for idx, token in enumerate(x):
|
||||
original_pos, old_token = token_order[idx]
|
||||
|
||||
# Get position of the token again as it will likely change as tokens are being replaced
|
||||
pos = p.prompt.find(old_token)
|
||||
if original_pos >= 0:
|
||||
# Avoid trying to replace what was just replaced by searching later in the prompt string
|
||||
p.prompt = p.prompt[0:search_from_pos] + p.prompt[search_from_pos:].replace(old_token, token, 1)
|
||||
|
||||
search_from_pos = pos + len(token)
|
||||
|
||||
|
||||
samplers_dict = {}
|
||||
for i, sampler in enumerate(modules.sd_samplers.samplers):
|
||||
|
@ -60,7 +82,8 @@ def format_value_add_label(p, opt, x):
|
|||
def format_value(p, opt, x):
|
||||
if type(x) == float:
|
||||
x = round(x, 8)
|
||||
|
||||
if type(x) == type(list()):
|
||||
x = str(x)
|
||||
return x
|
||||
|
||||
def do_nothing(p, x, xs):
|
||||
|
@ -89,6 +112,7 @@ axis_options = [
|
|||
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
|
||||
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
|
||||
AxisOption("Eta", float, apply_field("eta"), format_value_add_label),
|
||||
AxisOption("Prompt order", type(list()), apply_order, format_value),
|
||||
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
|
||||
]
|
||||
|
||||
|
@ -159,8 +183,12 @@ class Script(scripts.Script):
|
|||
if opt.label == 'Nothing':
|
||||
return [0]
|
||||
|
||||
if opt.type == type(list()):
|
||||
valslist = [x for x in vals]
|
||||
else:
|
||||
valslist = [x.strip() for x in vals.split(",")]
|
||||
|
||||
|
||||
if opt.type == int:
|
||||
valslist_ext = []
|
||||
|
||||
|
@ -212,9 +240,17 @@ class Script(scripts.Script):
|
|||
return valslist
|
||||
|
||||
x_opt = axis_options[x_type]
|
||||
|
||||
if x_opt.label == "Prompt order":
|
||||
x_values = list(permutations([x.strip() for x in x_values.split(",")]))
|
||||
|
||||
xs = process_axis(x_opt, x_values)
|
||||
|
||||
y_opt = axis_options[y_type]
|
||||
|
||||
if y_opt.label == "Prompt order":
|
||||
y_values = list(permutations([y.strip() for y in y_values.split(",")]))
|
||||
|
||||
ys = process_axis(y_opt, y_values)
|
||||
|
||||
def fix_axis_seeds(axis_opt, axis_list):
|
||||
|
|
Loading…
Reference in a new issue