Merge branch 'master' into fix/encode-pnginfo

This commit is contained in:
evshiron 2022-11-05 22:12:20 +08:00
commit b6cfaaa20b
10 changed files with 132 additions and 37 deletions

View file

@@ -231,6 +231,10 @@ class Api:
         return options

     def set_config(self, req: OptionsModel):
+        # currently req has all options fields even if you send a dict like { "send_seed": false }, which means it will
+        # overwrite all options with default values.
+        raise RuntimeError('Setting options via API is not supported')
+
         reqDict = vars(req)
         for o in reqDict:
             setattr(shared.opts, o, reqDict[o])
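The added comment is the rationale for the RuntimeError: OptionsModel instantiates every option field with its default, so vars(req) would overwrite settings the caller never mentioned. A minimal sketch of what a partial update could look like instead, assuming Pydantic v1 semantics where .dict(exclude_unset=True) drops fields absent from the request body — this is only an illustration, not the endpoint's actual behaviour:

    def set_config(self, req: OptionsModel):
        # Hypothetical partial update: apply only the keys the client actually sent;
        # fields left unset in the request keep their current values.
        for key, value in req.dict(exclude_unset=True).items():
            setattr(shared.opts, key, value)
        shared.opts.save(shared.config_filename)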

View file

@@ -1,6 +1,6 @@
 import inspect
 from pydantic import BaseModel, Field, create_model
-from typing import Any, Optional, Union
+from typing import Any, Optional
 from typing_extensions import Literal
 from inflection import underscore
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
@@ -185,22 +185,22 @@ _options = vars(parser)['_option_string_actions']
 for key in _options:
     if(_options[key].dest != 'help'):
         flag = _options[key]
         _type = str
-        if(_options[key].default != None): _type = type(_options[key].default)
+        if _options[key].default is not None: _type = type(_options[key].default)
         flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})

 FlagsModel = create_model("Flags", **flags)

 class SamplerItem(BaseModel):
     name: str = Field(title="Name")
     aliases: list[str] = Field(title="Aliases")
     options: dict[str, str] = Field(title="Options")

 class UpscalerItem(BaseModel):
     name: str = Field(title="Name")
-    model_name: str | None = Field(title="Model Name")
-    model_path: str | None = Field(title="Path")
-    model_url: str | None = Field(title="URL")
+    model_name: Optional[str] = Field(title="Model Name")
+    model_path: Optional[str] = Field(title="Path")
+    model_url: Optional[str] = Field(title="URL")

 class SDModelItem(BaseModel):
     title: str = Field(title="Title")
@@ -211,21 +211,21 @@ class SDModelItem(BaseModel):

 class HypernetworkItem(BaseModel):
     name: str = Field(title="Name")
-    path: str | None = Field(title="Path")
+    path: Optional[str] = Field(title="Path")

 class FaceRestorerItem(BaseModel):
     name: str = Field(title="Name")
-    cmd_dir: str | None = Field(title="Path")
+    cmd_dir: Optional[str] = Field(title="Path")

 class RealesrganItem(BaseModel):
     name: str = Field(title="Name")
-    path: str | None = Field(title="Path")
-    scale: int | None = Field(title="Scale")
+    path: Optional[str] = Field(title="Path")
+    scale: Optional[int] = Field(title="Scale")

 class PromptStyleItem(BaseModel):
     name: str = Field(title="Name")
-    prompt: str | None = Field(title="Prompt")
-    negative_prompt: str | None = Field(title="Negative Prompt")
+    prompt: Optional[str] = Field(title="Prompt")
+    negative_prompt: Optional[str] = Field(title="Negative Prompt")

 class ArtistItem(BaseModel):
     name: str = Field(title="Name")
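The `str | None` annotations being swapped out are PEP 604 union syntax, which Pydantic evaluates at class-creation time and which fails before Python 3.10; Optional[str] expresses the same type on earlier interpreters. A minimal illustration (ExampleItem is a made-up model, not part of the API):

    from typing import Optional
    from pydantic import BaseModel, Field

    class ExampleItem(BaseModel):
        # Optional[str] is Union[str, None] and works on Python 3.7+,
        # whereas `str | None` raises a TypeError here before Python 3.10.
        model_name: Optional[str] = Field(default=None, title="Model Name")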

View file

@@ -34,8 +34,11 @@ class Extension:
         if repo is None or repo.bare:
             self.remote = None
         else:
-            self.remote = next(repo.remote().urls, None)
-            self.status = 'unknown'
+            try:
+                self.remote = next(repo.remote().urls, None)
+                self.status = 'unknown'
+            except Exception:
+                self.remote = None

     def list_files(self, subdir, extension):
         from modules import scripts

View file

@@ -22,6 +22,8 @@ from collections import defaultdict, deque
 from statistics import stdev, mean

+optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
+
 class HypernetworkModule(torch.nn.Module):
     multiplier = 1.0
     activation_dict = {
@@ -142,6 +144,8 @@ class Hypernetwork:
         self.use_dropout = use_dropout
         self.activate_output = activate_output
         self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
+        self.optimizer_name = None
+        self.optimizer_state_dict = None

         for size in enable_sizes or []:
             self.layers[size] = (
@@ -163,6 +167,7 @@ class Hypernetwork:

     def save(self, filename):
         state_dict = {}
+        optimizer_saved_dict = {}

         for k, v in self.layers.items():
             state_dict[k] = (v[0].state_dict(), v[1].state_dict())
@@ -178,8 +183,15 @@ class Hypernetwork:
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
         state_dict['activate_output'] = self.activate_output
         state_dict['last_layer_dropout'] = self.last_layer_dropout
+        if self.optimizer_name is not None:
+            optimizer_saved_dict['optimizer_name'] = self.optimizer_name

         torch.save(state_dict, filename)
+        if shared.opts.save_optimizer_state and self.optimizer_state_dict:
+            optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
+            optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
+            torch.save(optimizer_saved_dict, filename + '.optim')

     def load(self, filename):
         self.filename = filename
@@ -202,6 +214,18 @@ class Hypernetwork:
         print(f"Activate last layer is set to {self.activate_output}")
         self.last_layer_dropout = state_dict.get('last_layer_dropout', False)

+        optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
+        self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
+        print(f"Optimizer name is {self.optimizer_name}")
+        if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
+            self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+        else:
+            self.optimizer_state_dict = None
+        if self.optimizer_state_dict:
+            print("Loaded existing optimizer from checkpoint")
+        else:
+            print("No saved optimizer exists in checkpoint")
+
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
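The save()/load() changes above give each hypernetwork an optional sidecar file, <name>.pt.optim: a plain dict holding optimizer_name, a hash of the companion .pt (so a stale or mismatched .optim is ignored on load), and the raw optimizer_state_dict. A quick way to inspect one, with a made-up path:

    import torch

    # Path is an example; the keys mirror the save()/load() code above.
    optim = torch.load('models/hypernetworks/example.pt.optim', map_location='cpu')
    print(optim.get('optimizer_name'))         # e.g. 'AdamW'
    print(optim.get('hash'))                   # model_hash() of the matching .pt file
    state = optim.get('optimizer_state_dict')  # fed to optimizer.load_state_dict() on resume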
@@ -219,11 +243,11 @@ class Hypernetwork:

 def list_hypernetworks(path):
     res = {}
-    for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
+    for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
         name = os.path.splitext(os.path.basename(filename))[0]
         # Prevent a hypothetical "None.pt" from being listed.
         if name != "None":
-            res[name] = filename
+            res[name + f"({sd_models.model_hash(filename)})"] = filename
     return res
@@ -358,6 +382,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     shared.state.textinfo = "Initializing hypernetwork training..."
     shared.state.job_count = steps

+    hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')

     log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
@@ -404,8 +429,22 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     weights = hypernetwork.weights()
     for weight in weights:
         weight.requires_grad = True
-    # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
-    optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
+
+    # Here we use optimizer from saved HN, or we can specify as UI option.
+    if hypernetwork.optimizer_name in optimizer_dict:
+        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
+        optimizer_name = hypernetwork.optimizer_name
+    else:
+        print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
+        optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
+        optimizer_name = 'AdamW'
+
+    if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
+        try:
+            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
+        except RuntimeError as e:
+            print("Cannot resume from saved optimizer!")
+            print(e)

     steps_without_grad = 0
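optimizer_dict, defined near the top of the file, simply maps class names to the optimizer classes exported by torch.optim, which is what lets the optimizer_name stored in the .optim file pick the matching constructor when training resumes. A self-contained sketch of the same round-trip, with a dummy parameter standing in for the hypernetwork weights:

    import inspect
    import torch

    # Same construction as optimizer_dict above: every class in torch.optim
    # except the abstract Optimizer base, keyed by name ('Adam', 'AdamW', 'SGD', ...).
    optimizer_dict = {name: cls for name, cls in inspect.getmembers(torch.optim, inspect.isclass)
                      if name != "Optimizer"}

    weights = [torch.nn.Parameter(torch.zeros(4))]   # stand-in for hypernetwork.weights()
    opt = optimizer_dict['AdamW'](params=weights, lr=1e-3)
    saved = opt.state_dict()                         # what ends up in the .optim file

    resumed = optimizer_dict['AdamW'](params=weights, lr=1e-3)
    resumed.load_state_dict(saved)                   # can fail if the parameters no longer match,
                                                     # hence the try/except in the training code above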
@@ -467,7 +506,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
                 # Before saving, change name to match current checkpoint.
                 hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
                 last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
+                hypernetwork.optimizer_name = optimizer_name
+                if shared.opts.save_optimizer_state:
+                    hypernetwork.optimizer_state_dict = optimizer.state_dict()
                 save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
+                hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.

             textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
                 "loss": f"{previous_mean_loss:.7f}",
@@ -530,8 +573,12 @@ Last saved image: {html.escape(last_saved_image)}<br/>
     report_statistics(loss_dict)

     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+    hypernetwork.optimizer_name = optimizer_name
+    if shared.opts.save_optimizer_state:
+        hypernetwork.optimizer_state_dict = optimizer.state_dict()
     save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
+    del optimizer
+    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.

     return hypernetwork, filename

 def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):

View file

@@ -9,7 +9,7 @@ from modules import devices, sd_hijack, shared
 from modules.hypernetworks import hypernetwork

 not_available = ["hardswish", "multiheadattention"]
-keys = ["linear"] + list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
+keys = list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)

 def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
     # Remove illegal characters from name.

View file

@@ -86,6 +86,10 @@ parser.add_argument("--nowebui", action='store_true', help="use api=True to laun
 parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
 parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
 parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
+parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origins", default=None)
+parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
+parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
+parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)

 cmd_opts = parser.parse_args()

 restricted_opts = {
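For reference, the new flags land on cmd_opts like any other argparse option; a sketch with made-up values (normally they come from the real command line, not an explicit argv list):

    # Illustration only: parse_args() is shown with an explicit argv list and example paths.
    cmd_opts = parser.parse_args([
        "--cors-allow-origins", "https://example.com,http://localhost:3000",
        "--tls-keyfile", "ssl/key.pem",
        "--tls-certfile", "ssl/cert.pem",
        "--server-name", "sd.local",
    ])
    print(cmd_opts.cors_allow_origins.split(','))  # origins list later handed to CORSMiddleware
    print(cmd_opts.tls_keyfile, cmd_opts.tls_certfile, cmd_opts.server_name)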
@@ -147,9 +151,9 @@ class State:
         self.interrupted = True

     def nextjob(self):
         if opts.show_progress_every_n_steps == -1:
             self.do_set_current_image()

         self.job_no += 1
         self.sampling_step = 0
         self.current_image_sampling_step = 0
@@ -198,7 +202,7 @@ class State:
             return
         if self.current_latent is None:
             return

         if opts.show_progress_grid:
             self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
         else:
@@ -317,6 +321,7 @@ options_templates.update(options_section(('system', "System"), {

 options_templates.update(options_section(('training', "Training"), {
     "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
+    "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
     "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
     "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
     "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
@@ -406,7 +411,8 @@ class Options:
         if key in self.data or key in self.data_labels:
             assert not cmd_opts.freeze_settings, "changing settings is disabled"

-            comp_args = opts.data_labels[key].component_args
+            info = opts.data_labels.get(key, None)
+            comp_args = info.component_args if info else None
             if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
                 raise RuntimeError(f"not possible to set {key} because it is restricted")

View file

@@ -1446,17 +1446,19 @@ def create_ui(wrap_gradio_gpu_call):
             continue

         oldval = opts.data.get(key, None)
-        setattr(opts, key, value)
+        try:
+            setattr(opts, key, value)
+        except RuntimeError:
+            continue
         if oldval != value:
             if opts.data_labels[key].onchange is not None:
                 opts.data_labels[key].onchange()

             changed += 1

-    opts.save(shared.config_filename)
+    try:
+        opts.save(shared.config_filename)
+    except RuntimeError:
+        return opts.dumpjson(), f'{changed} settings changed without save.'
     return opts.dumpjson(), f'{changed} settings changed.'

 def run_settings_single(value, key):

View file

@@ -188,7 +188,7 @@ def refresh_available_extensions_from_data():
         code += f"""
             <tr>
-                <td><a href="{html.escape(url)}">{html.escape(name)}</a></td>
+                <td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a></td>
                 <td>{html.escape(description)}</td>
                 <td>{install_code}</td>
             </tr>

View file

@@ -57,10 +57,18 @@ class Upscaler:
         self.scale = scale
         dest_w = img.width * scale
         dest_h = img.height * scale
+
         for i in range(3):
-            if img.width > dest_w and img.height > dest_h:
-                break
+            shape = (img.width, img.height)
+
             img = self.do_upscale(img, selected_model)
+
+            if shape == (img.width, img.height):
+                break
+
+            if img.width >= dest_w and img.height >= dest_h:
+                break
+
         if img.width != dest_w or img.height != dest_h:
             img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
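The rewritten loop calls do_upscale() up to three times, stopping early if a pass no longer changes the image or once the target size is reached, then snaps to the exact destination size with a LANCZOS resize. A standalone sketch of that control flow, with a fake 2x upscaler standing in for self.do_upscale():

    from PIL import Image

    LANCZOS = Image.Resampling.LANCZOS if hasattr(Image, "Resampling") else Image.LANCZOS

    def fake_upscale(img):
        # Stand-in for self.do_upscale(img, selected_model): doubles the image each pass.
        return img.resize((img.width * 2, img.height * 2))

    img = Image.new("RGB", (100, 100))
    dest_w, dest_h = img.width * 4, img.height * 4

    for _ in range(3):
        shape = (img.width, img.height)
        img = fake_upscale(img)
        if shape == (img.width, img.height):
            break                                    # upscaler made no progress
        if img.width >= dest_w and img.height >= dest_h:
            break                                    # target reached or exceeded

    if img.width != dest_w or img.height != dest_h:
        img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)

    print(img.size)  # (400, 400): two 2x passes hit the 4x target exactly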

View file

@@ -5,6 +5,7 @@ import importlib
 import signal
 import threading

 from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware

 from modules.paths import script_path
@@ -34,7 +35,7 @@ from modules.shared import cmd_opts
 import modules.hypernetworks.hypernetwork

 queue_lock = threading.Lock()
+server_name = "0.0.0.0" if cmd_opts.listen else cmd_opts.server_name

 def wrap_queued_call(func):
     def f(*args, **kwargs):
@@ -85,6 +86,20 @@ def initialize():
     shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
     shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)

+    if cmd_opts.tls_keyfile is not None and cmd_opts.tls_certfile is not None:
+        try:
+            if not os.path.exists(cmd_opts.tls_keyfile):
+                print("Invalid path to TLS keyfile given")
+            if not os.path.exists(cmd_opts.tls_certfile):
+                print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
+        except TypeError:
+            cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
+            print("TLS setup invalid, running webui without TLS")
+        else:
+            print("Running with TLS")
+
     # make the program just exit at ctrl+c without waiting for anything
     def sigint_handler(sig, frame):
         print(f'Interrupted with signal {sig} in {frame}')
@@ -93,6 +108,11 @@ def initialize():

     signal.signal(signal.SIGINT, sigint_handler)

+
+def setup_cors(app):
+    if cmd_opts.cors_allow_origins:
+        app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'])
+
 def create_api(app):
     from modules.api.api import Api
     api = Api(app, queue_lock)
@@ -114,6 +134,7 @@ def api_only():
     initialize()

     app = FastAPI()
+    setup_cors(app)
     app.add_middleware(GZipMiddleware, minimum_size=1000)
     api = create_api(app)
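setup_cors() only attaches FastAPI's CORSMiddleware when --cors-allow-origins is supplied, splitting the flag on commas into an explicit origin whitelist. A standalone equivalent, with a made-up origins string in place of cmd_opts.cors_allow_origins:

    from fastapi import FastAPI
    from fastapi.middleware.cors import CORSMiddleware

    cors_allow_origins = "https://example.com,http://localhost:3000"  # example value

    app = FastAPI()
    if cors_allow_origins:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=cors_allow_origins.split(','),  # exact origins only, no wildcard
            allow_methods=['*'],                          # allow any HTTP method from those origins
        )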
@@ -131,8 +152,10 @@ def webui():
         app, local_url, share_url = demo.launch(
             share=cmd_opts.share,
-            server_name="0.0.0.0" if cmd_opts.listen else None,
+            server_name=server_name,
             server_port=cmd_opts.port,
+            ssl_keyfile=cmd_opts.tls_keyfile,
+            ssl_certfile=cmd_opts.tls_certfile,
             debug=cmd_opts.gradio_debug,
             auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None,
             inbrowser=cmd_opts.autolaunch,
@@ -147,6 +170,8 @@ def webui():
         # running its code. We disable this here. Suggested by RyotaK.
         app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']

+        setup_cors(app)
+
         app.add_middleware(GZipMiddleware, minimum_size=1000)

         if launch_api: