X/Y plot support for switching checkpoints.
This commit is contained in:
parent
99585b3514
commit
304222ef94
3 changed files with 19 additions and 2 deletions
|
@ -127,9 +127,9 @@ def load_model():
|
|||
return sd_model
|
||||
|
||||
|
||||
def reload_model_weights(sd_model):
|
||||
def reload_model_weights(sd_model, info=None):
|
||||
from modules import lowvram, devices
|
||||
checkpoint_info = select_checkpoint()
|
||||
checkpoint_info = info or select_checkpoint()
|
||||
|
||||
if sd_model.sd_model_checkpint == checkpoint_info.filename:
|
||||
return
|
||||
|
|
|
@ -66,6 +66,8 @@ titles = {
|
|||
"Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both",
|
||||
"Apply style": "Insert selected styles into prompt fields",
|
||||
"Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style use that as placeholder for your prompt when you use the style in the future.",
|
||||
|
||||
"Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.",
|
||||
}
|
||||
|
||||
function gradioApp(){
|
||||
|
|
|
@ -10,7 +10,9 @@ import gradio as gr
|
|||
from modules import images
|
||||
from modules.processing import process_images, Processed
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
import modules.sd_samplers
|
||||
import modules.sd_models
|
||||
import re
|
||||
|
||||
|
||||
|
@ -41,6 +43,15 @@ def apply_sampler(p, x, xs):
|
|||
p.sampler_index = sampler_index
|
||||
|
||||
|
||||
def apply_checkpoint(p, x, xs):
|
||||
applicable = [info for info in modules.sd_models.checkpoints_list.values() if x in info.title]
|
||||
assert len(applicable) > 0, f'Checkpoint {x} for found'
|
||||
|
||||
info = applicable[0]
|
||||
|
||||
modules.sd_models.reload_model_weights(shared.sd_model, info)
|
||||
|
||||
|
||||
def format_value_add_label(p, opt, x):
|
||||
if type(x) == float:
|
||||
x = round(x, 8)
|
||||
|
@ -74,6 +85,7 @@ axis_options = [
|
|||
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
|
||||
AxisOption("Prompt S/R", str, apply_prompt, format_value),
|
||||
AxisOption("Sampler", str, apply_sampler, format_value),
|
||||
AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
|
||||
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
|
||||
]
|
||||
|
||||
|
@ -215,4 +227,7 @@ class Script(scripts.Script):
|
|||
if opts.grid_save:
|
||||
images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p)
|
||||
|
||||
# restore checkpoint in case it was changed by axes
|
||||
modules.sd_models.reload_model_weights(shared.sd_model)
|
||||
|
||||
return processed
|
||||
|
|
Loading…
Reference in a new issue