Merge branch 'master' into esrgan_mod

This commit is contained in:
victorca25 2022-10-09 14:11:22 +02:00 committed by GitHub
commit 53154ba10a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 13 additions and 10 deletions

View file

@ -49,15 +49,18 @@ def list_hypernetworks(path):
def load_hypernetwork(filename): def load_hypernetwork(filename):
print(f"Loading hypernetwork {filename}")
path = shared.hypernetworks.get(filename, None) path = shared.hypernetworks.get(filename, None)
if (path is not None): if path is not None:
print(f"Loading hypernetwork {filename}")
try: try:
shared.loaded_hypernetwork = Hypernetwork(path) shared.loaded_hypernetwork = Hypernetwork(path)
except Exception: except Exception:
print(f"Error loading hypernetwork {path}", file=sys.stderr) print(f"Error loading hypernetwork {path}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
else: else:
if shared.loaded_hypernetwork is not None:
print(f"Unloading hypernetwork")
shared.loaded_hypernetwork = None shared.loaded_hypernetwork = None

View file

@ -284,6 +284,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Size": f"{p.width}x{p.height}", "Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')),
"Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch), "Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),

View file

@ -5,7 +5,6 @@ from collections import namedtuple
import torch import torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from modules import shared, modelloader, devices from modules import shared, modelloader, devices

View file

@ -242,6 +242,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"return_grid": OptionInfo(True, "Show grid in results for web"), "return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
"font": OptionInfo("", "Font for image grids that have text"), "font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),

View file

@ -37,7 +37,7 @@ class Upscaler:
self.pre_pad = 0 self.pre_pad = 0
self.mod_scale = None self.mod_scale = None
if self.model_path is not None and self.name: if self.model_path is None and self.name:
self.model_path = os.path.join(models_path, self.name) self.model_path = os.path.join(models_path, self.name)
if self.model_path and create_dirs: if self.model_path and create_dirs:
os.makedirs(self.model_path, exist_ok=True) os.makedirs(self.model_path, exist_ok=True)

View file

@ -10,7 +10,7 @@ import numpy as np
import modules.scripts as scripts import modules.scripts as scripts
import gradio as gr import gradio as gr
from modules import images from modules import images, hypernetwork
from modules.processing import process_images, Processed, get_correct_sampler from modules.processing import process_images, Processed, get_correct_sampler
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
@ -80,8 +80,7 @@ def apply_checkpoint(p, x, xs):
def apply_hypernetwork(p, x, xs): def apply_hypernetwork(p, x, xs):
hn = shared.hypernetworks.get(x, None) hypernetwork.load_hypernetwork(x)
opts.data["sd_hypernetwork"] = hn.name if hn is not None else 'None'
def format_value_add_label(p, opt, x): def format_value_add_label(p, opt, x):
@ -203,8 +202,6 @@ class Script(scripts.Script):
p.batch_size = 1 p.batch_size = 1
initial_hn = opts.sd_hypernetwork
def process_axis(opt, vals): def process_axis(opt, vals):
if opt.label == 'Nothing': if opt.label == 'Nothing':
return [0] return [0]
@ -262,6 +259,7 @@ class Script(scripts.Script):
# Confirm options are valid before starting # Confirm options are valid before starting
if opt.label == "Sampler": if opt.label == "Sampler":
samplers_dict = build_samplers_dict(p)
for sampler_val in valslist: for sampler_val in valslist:
if sampler_val.lower() not in samplers_dict.keys(): if sampler_val.lower() not in samplers_dict.keys():
raise RuntimeError(f"Unknown sampler: {sampler_val}") raise RuntimeError(f"Unknown sampler: {sampler_val}")
@ -321,6 +319,6 @@ class Script(scripts.Script):
# restore checkpoint in case it was changed by axes # restore checkpoint in case it was changed by axes
modules.sd_models.reload_model_weights(shared.sd_model) modules.sd_models.reload_model_weights(shared.sd_model)
opts.data["sd_hypernetwork"] = initial_hn hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
return processed return processed