fix btoken hypernetworks in XY plot

This commit is contained in:
AUTOMATIC 2022-10-09 14:33:22 +03:00
parent 77a719648d
commit 542a3d3a4a
2 changed files with 8 additions and 8 deletions

View file

@ -49,15 +49,18 @@ def list_hypernetworks(path):
def load_hypernetwork(filename): def load_hypernetwork(filename):
print(f"Loading hypernetwork {filename}")
path = shared.hypernetworks.get(filename, None) path = shared.hypernetworks.get(filename, None)
if (path is not None): if path is not None:
print(f"Loading hypernetwork {filename}")
try: try:
shared.loaded_hypernetwork = Hypernetwork(path) shared.loaded_hypernetwork = Hypernetwork(path)
except Exception: except Exception:
print(f"Error loading hypernetwork {path}", file=sys.stderr) print(f"Error loading hypernetwork {path}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
else: else:
if shared.loaded_hypernetwork is not None:
print(f"Unloading hypernetwork")
shared.loaded_hypernetwork = None shared.loaded_hypernetwork = None

View file

@ -10,7 +10,7 @@ import numpy as np
import modules.scripts as scripts import modules.scripts as scripts
import gradio as gr import gradio as gr
from modules import images from modules import images, hypernetwork
from modules.processing import process_images, Processed, get_correct_sampler from modules.processing import process_images, Processed, get_correct_sampler
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
@ -80,8 +80,7 @@ def apply_checkpoint(p, x, xs):
def apply_hypernetwork(p, x, xs): def apply_hypernetwork(p, x, xs):
hn = shared.hypernetworks.get(x, None) hypernetwork.load_hypernetwork(x)
opts.data["sd_hypernetwork"] = hn.name if hn is not None else 'None'
def format_value_add_label(p, opt, x): def format_value_add_label(p, opt, x):
@ -203,8 +202,6 @@ class Script(scripts.Script):
p.batch_size = 1 p.batch_size = 1
initial_hn = opts.sd_hypernetwork
def process_axis(opt, vals): def process_axis(opt, vals):
if opt.label == 'Nothing': if opt.label == 'Nothing':
return [0] return [0]
@ -321,6 +318,6 @@ class Script(scripts.Script):
# restore checkpoint in case it was changed by axes # restore checkpoint in case it was changed by axes
modules.sd_models.reload_model_weights(shared.sd_model) modules.sd_models.reload_model_weights(shared.sd_model)
opts.data["sd_hypernetwork"] = initial_hn hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
return processed return processed