X/Y plot support for switching checkpoints.
This commit is contained in:
parent
99585b3514
commit
304222ef94
3 changed files with 19 additions and 2 deletions
|
@ -127,9 +127,9 @@ def load_model():
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
|
||||||
def reload_model_weights(sd_model):
|
def reload_model_weights(sd_model, info=None):
|
||||||
from modules import lowvram, devices
|
from modules import lowvram, devices
|
||||||
checkpoint_info = select_checkpoint()
|
checkpoint_info = info or select_checkpoint()
|
||||||
|
|
||||||
if sd_model.sd_model_checkpint == checkpoint_info.filename:
|
if sd_model.sd_model_checkpint == checkpoint_info.filename:
|
||||||
return
|
return
|
||||||
|
|
|
@ -66,6 +66,8 @@ titles = {
|
||||||
"Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both",
|
"Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both",
|
||||||
"Apply style": "Insert selected styles into prompt fields",
|
"Apply style": "Insert selected styles into prompt fields",
|
||||||
"Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style use that as placeholder for your prompt when you use the style in the future.",
|
"Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style use that as placeholder for your prompt when you use the style in the future.",
|
||||||
|
|
||||||
|
"Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.",
|
||||||
}
|
}
|
||||||
|
|
||||||
function gradioApp(){
|
function gradioApp(){
|
||||||
|
|
|
@ -10,7 +10,9 @@ import gradio as gr
|
||||||
from modules import images
|
from modules import images
|
||||||
from modules.processing import process_images, Processed
|
from modules.processing import process_images, Processed
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
|
import modules.shared as shared
|
||||||
import modules.sd_samplers
|
import modules.sd_samplers
|
||||||
|
import modules.sd_models
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
@ -41,6 +43,15 @@ def apply_sampler(p, x, xs):
|
||||||
p.sampler_index = sampler_index
|
p.sampler_index = sampler_index
|
||||||
|
|
||||||
|
|
||||||
|
def apply_checkpoint(p, x, xs):
|
||||||
|
applicable = [info for info in modules.sd_models.checkpoints_list.values() if x in info.title]
|
||||||
|
assert len(applicable) > 0, f'Checkpoint {x} for found'
|
||||||
|
|
||||||
|
info = applicable[0]
|
||||||
|
|
||||||
|
modules.sd_models.reload_model_weights(shared.sd_model, info)
|
||||||
|
|
||||||
|
|
||||||
def format_value_add_label(p, opt, x):
|
def format_value_add_label(p, opt, x):
|
||||||
if type(x) == float:
|
if type(x) == float:
|
||||||
x = round(x, 8)
|
x = round(x, 8)
|
||||||
|
@ -74,6 +85,7 @@ axis_options = [
|
||||||
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
|
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
|
||||||
AxisOption("Prompt S/R", str, apply_prompt, format_value),
|
AxisOption("Prompt S/R", str, apply_prompt, format_value),
|
||||||
AxisOption("Sampler", str, apply_sampler, format_value),
|
AxisOption("Sampler", str, apply_sampler, format_value),
|
||||||
|
AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
|
||||||
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
|
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -215,4 +227,7 @@ class Script(scripts.Script):
|
||||||
if opts.grid_save:
|
if opts.grid_save:
|
||||||
images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p)
|
images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p)
|
||||||
|
|
||||||
|
# restore checkpoint in case it was changed by axes
|
||||||
|
modules.sd_models.reload_model_weights(shared.sd_model)
|
||||||
|
|
||||||
return processed
|
return processed
|
||||||
|
|
Loading…
Reference in a new issue