xy_grid: Refactor confirm functions
This commit is contained in:
parent
7dba1c07cb
commit
2fffd4bddc
1 changed files with 39 additions and 34 deletions
|
@ -77,12 +77,26 @@ def apply_sampler(p, x, xs):
|
||||||
p.sampler_index = sampler_index
|
p.sampler_index = sampler_index
|
||||||
|
|
||||||
|
|
||||||
|
def confirm_samplers(p, xs):
|
||||||
|
samplers_dict = build_samplers_dict(p)
|
||||||
|
for x in xs:
|
||||||
|
if x.lower() not in samplers_dict.keys():
|
||||||
|
raise RuntimeError(f"Unknown sampler: {x}")
|
||||||
|
|
||||||
|
|
||||||
def apply_checkpoint(p, x, xs):
|
def apply_checkpoint(p, x, xs):
|
||||||
info = modules.sd_models.get_closet_checkpoint_match(x)
|
info = modules.sd_models.get_closet_checkpoint_match(x)
|
||||||
assert info is not None, f'Checkpoint for {x} not found'
|
if info is None:
|
||||||
|
raise RuntimeError(f"Unknown checkpoint: {x}")
|
||||||
modules.sd_models.reload_model_weights(shared.sd_model, info)
|
modules.sd_models.reload_model_weights(shared.sd_model, info)
|
||||||
|
|
||||||
|
|
||||||
|
def confirm_checkpoints(p, xs):
|
||||||
|
for x in xs:
|
||||||
|
if modules.sd_models.get_closet_checkpoint_match(x) is None:
|
||||||
|
raise RuntimeError(f"Unknown checkpoint: {x}")
|
||||||
|
|
||||||
|
|
||||||
def apply_hypernetwork(p, x, xs):
|
def apply_hypernetwork(p, x, xs):
|
||||||
if x.lower() in ["", "none"]:
|
if x.lower() in ["", "none"]:
|
||||||
name = None
|
name = None
|
||||||
|
@ -93,7 +107,7 @@ def apply_hypernetwork(p, x, xs):
|
||||||
hypernetwork.load_hypernetwork(name)
|
hypernetwork.load_hypernetwork(name)
|
||||||
|
|
||||||
|
|
||||||
def confirm_hypernetworks(xs):
|
def confirm_hypernetworks(p, xs):
|
||||||
for x in xs:
|
for x in xs:
|
||||||
if x.lower() in ["", "none"]:
|
if x.lower() in ["", "none"]:
|
||||||
continue
|
continue
|
||||||
|
@ -135,29 +149,29 @@ def str_permutations(x):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"])
|
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"])
|
||||||
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"])
|
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"])
|
||||||
|
|
||||||
|
|
||||||
axis_options = [
|
axis_options = [
|
||||||
AxisOption("Nothing", str, do_nothing, format_nothing),
|
AxisOption("Nothing", str, do_nothing, format_nothing, None),
|
||||||
AxisOption("Seed", int, apply_field("seed"), format_value_add_label),
|
AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None),
|
||||||
AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label),
|
AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None),
|
||||||
AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label),
|
AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None),
|
||||||
AxisOption("Steps", int, apply_field("steps"), format_value_add_label),
|
AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None),
|
||||||
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
|
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None),
|
||||||
AxisOption("Prompt S/R", str, apply_prompt, format_value),
|
AxisOption("Prompt S/R", str, apply_prompt, format_value, None),
|
||||||
AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list),
|
AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None),
|
||||||
AxisOption("Sampler", str, apply_sampler, format_value),
|
AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers),
|
||||||
AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
|
AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints),
|
||||||
AxisOption("Hypernetwork", str, apply_hypernetwork, format_value),
|
AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks),
|
||||||
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label),
|
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None),
|
||||||
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label),
|
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None),
|
||||||
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
|
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None),
|
||||||
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
|
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None),
|
||||||
AxisOption("Eta", float, apply_field("eta"), format_value_add_label),
|
AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
|
||||||
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label),
|
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
|
||||||
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
|
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -283,19 +297,10 @@ class Script(scripts.Script):
|
||||||
valslist = list(permutations(valslist))
|
valslist = list(permutations(valslist))
|
||||||
|
|
||||||
valslist = [opt.type(x) for x in valslist]
|
valslist = [opt.type(x) for x in valslist]
|
||||||
|
|
||||||
# Confirm options are valid before starting
|
# Confirm options are valid before starting
|
||||||
if opt.label == "Sampler":
|
if opt.confirm:
|
||||||
samplers_dict = build_samplers_dict(p)
|
opt.confirm(p, valslist)
|
||||||
for sampler_val in valslist:
|
|
||||||
if sampler_val.lower() not in samplers_dict.keys():
|
|
||||||
raise RuntimeError(f"Unknown sampler: {sampler_val}")
|
|
||||||
elif opt.label == "Checkpoint name":
|
|
||||||
for ckpt_val in valslist:
|
|
||||||
if modules.sd_models.get_closet_checkpoint_match(ckpt_val) is None:
|
|
||||||
raise RuntimeError(f"Checkpoint for {ckpt_val} not found")
|
|
||||||
elif opt.label == "Hypernetwork":
|
|
||||||
confirm_hypernetworks(valslist)
|
|
||||||
|
|
||||||
return valslist
|
return valslist
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue