Safeguard setting restore logic against exceptions

also useful for keeping settings cache and restore logic together, and nice for code reuse (other third party scripts can import this class)
This commit is contained in:
Greg Fuller 2022-10-16 12:10:07 -07:00 committed by AUTOMATIC1111
parent 62edfae257
commit cccc5a20fc

View file

@ -233,6 +233,21 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_
return processed_result return processed_result
class SharedSettingsStackHelper(object):
def __enter__(self):
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
self.hypernetwork = opts.sd_hypernetwork
self.model = shared.sd_model
def __exit__(self, exc_type, exc_value, tb):
modules.sd_models.reload_model_weights(self.model)
hypernetwork.load_hypernetwork(self.hypernetwork)
hypernetwork.apply_strength()
opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*") re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")
@ -267,9 +282,6 @@ class Script(scripts.Script):
if not opts.return_grid: if not opts.return_grid:
p.batch_size = 1 p.batch_size = 1
CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
def process_axis(opt, vals): def process_axis(opt, vals):
if opt.label == 'Nothing': if opt.label == 'Nothing':
return [0] return [0]
@ -367,27 +379,19 @@ class Script(scripts.Script):
return process_images(pc) return process_images(pc)
processed = draw_xy_grid( with SharedSettingsStackHelper():
p, processed = draw_xy_grid(
xs=xs, p,
ys=ys, xs=xs,
x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], ys=ys,
y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
cell=cell, y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
draw_legend=draw_legend, cell=cell,
include_lone_images=include_lone_images draw_legend=draw_legend,
) include_lone_images=include_lone_images
)
if opts.grid_save: if opts.grid_save:
images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p) images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p)
# restore checkpoint in case it was changed by axes
modules.sd_models.reload_model_weights(shared.sd_model)
hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
hypernetwork.apply_strength()
opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers
return processed return processed