xy_grid: Refactor confirm functions
This commit is contained in:
parent
7dba1c07cb
commit
2fffd4bddc
1 changed files with 39 additions and 34 deletions
|
@ -77,12 +77,26 @@ def apply_sampler(p, x, xs):
|
|||
p.sampler_index = sampler_index
|
||||
|
||||
|
||||
def confirm_samplers(p, xs):
|
||||
samplers_dict = build_samplers_dict(p)
|
||||
for x in xs:
|
||||
if x.lower() not in samplers_dict.keys():
|
||||
raise RuntimeError(f"Unknown sampler: {x}")
|
||||
|
||||
|
||||
def apply_checkpoint(p, x, xs):
|
||||
info = modules.sd_models.get_closet_checkpoint_match(x)
|
||||
assert info is not None, f'Checkpoint for {x} not found'
|
||||
if info is None:
|
||||
raise RuntimeError(f"Unknown checkpoint: {x}")
|
||||
modules.sd_models.reload_model_weights(shared.sd_model, info)
|
||||
|
||||
|
||||
def confirm_checkpoints(p, xs):
|
||||
for x in xs:
|
||||
if modules.sd_models.get_closet_checkpoint_match(x) is None:
|
||||
raise RuntimeError(f"Unknown checkpoint: {x}")
|
||||
|
||||
|
||||
def apply_hypernetwork(p, x, xs):
|
||||
if x.lower() in ["", "none"]:
|
||||
name = None
|
||||
|
@ -93,7 +107,7 @@ def apply_hypernetwork(p, x, xs):
|
|||
hypernetwork.load_hypernetwork(name)
|
||||
|
||||
|
||||
def confirm_hypernetworks(xs):
|
||||
def confirm_hypernetworks(p, xs):
|
||||
for x in xs:
|
||||
if x.lower() in ["", "none"]:
|
||||
continue
|
||||
|
@ -135,29 +149,29 @@ def str_permutations(x):
|
|||
return x
|
||||
|
||||
|
||||
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"])
|
||||
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"])
|
||||
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"])
|
||||
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"])
|
||||
|
||||
|
||||
axis_options = [
|
||||
AxisOption("Nothing", str, do_nothing, format_nothing),
|
||||
AxisOption("Seed", int, apply_field("seed"), format_value_add_label),
|
||||
AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label),
|
||||
AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label),
|
||||
AxisOption("Steps", int, apply_field("steps"), format_value_add_label),
|
||||
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
|
||||
AxisOption("Prompt S/R", str, apply_prompt, format_value),
|
||||
AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list),
|
||||
AxisOption("Sampler", str, apply_sampler, format_value),
|
||||
AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
|
||||
AxisOption("Hypernetwork", str, apply_hypernetwork, format_value),
|
||||
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label),
|
||||
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label),
|
||||
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
|
||||
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
|
||||
AxisOption("Eta", float, apply_field("eta"), format_value_add_label),
|
||||
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label),
|
||||
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
|
||||
AxisOption("Nothing", str, do_nothing, format_nothing, None),
|
||||
AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None),
|
||||
AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None),
|
||||
AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None),
|
||||
AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None),
|
||||
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None),
|
||||
AxisOption("Prompt S/R", str, apply_prompt, format_value, None),
|
||||
AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None),
|
||||
AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers),
|
||||
AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints),
|
||||
AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks),
|
||||
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None),
|
||||
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None),
|
||||
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None),
|
||||
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None),
|
||||
AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
|
||||
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
|
||||
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
|
||||
]
|
||||
|
||||
|
||||
|
@ -283,19 +297,10 @@ class Script(scripts.Script):
|
|||
valslist = list(permutations(valslist))
|
||||
|
||||
valslist = [opt.type(x) for x in valslist]
|
||||
|
||||
|
||||
# Confirm options are valid before starting
|
||||
if opt.label == "Sampler":
|
||||
samplers_dict = build_samplers_dict(p)
|
||||
for sampler_val in valslist:
|
||||
if sampler_val.lower() not in samplers_dict.keys():
|
||||
raise RuntimeError(f"Unknown sampler: {sampler_val}")
|
||||
elif opt.label == "Checkpoint name":
|
||||
for ckpt_val in valslist:
|
||||
if modules.sd_models.get_closet_checkpoint_match(ckpt_val) is None:
|
||||
raise RuntimeError(f"Checkpoint for {ckpt_val} not found")
|
||||
elif opt.label == "Hypernetwork":
|
||||
confirm_hypernetworks(valslist)
|
||||
if opt.confirm:
|
||||
opt.confirm(p, valslist)
|
||||
|
||||
return valslist
|
||||
|
||||
|
|
Loading…
Reference in a new issue