Make it so that upscalers are not repeated when restarting UI.

This commit is contained in:
AUTOMATIC 2023-01-03 18:38:21 +03:00
parent e9fb9bb0c2
commit 2d5a5076bb
2 changed files with 27 additions and 7 deletions

View file

@ -123,6 +123,23 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
pass pass
builtin_upscaler_classes = []
forbidden_upscaler_classes = set()
def list_builtin_upscalers():
load_upscalers()
builtin_upscaler_classes.clear()
builtin_upscaler_classes.extend(Upscaler.__subclasses__())
def forbid_loaded_nonbuiltin_upscalers():
for cls in Upscaler.__subclasses__():
if cls not in builtin_upscaler_classes:
forbidden_upscaler_classes.add(cls)
def load_upscalers(): def load_upscalers():
# We can only do this 'magic' method to dynamically load upscalers if they are referenced, # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__ # so we'll try to import any _model.py files before looking in __subclasses__
@ -139,6 +156,9 @@ def load_upscalers():
datas = [] datas = []
commandline_options = vars(shared.cmd_opts) commandline_options = vars(shared.cmd_opts)
for cls in Upscaler.__subclasses__(): for cls in Upscaler.__subclasses__():
if cls in forbidden_upscaler_classes:
continue
name = cls.__name__ name = cls.__name__
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path" cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
scaler = cls(commandline_options.get(cmd_name, None)) scaler = cls(commandline_options.get(cmd_name, None))

View file

@ -1,4 +1,5 @@
import os import os
import sys
import threading import threading
import time import time
import importlib import importlib
@ -55,8 +56,8 @@ def initialize():
gfpgan.setup_model(cmd_opts.gfpgan_models_path) gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration()) shared.face_restorers.append(modules.face_restoration.FaceRestoration())
modelloader.list_builtin_upscalers()
modules.scripts.load_scripts() modules.scripts.load_scripts()
modelloader.load_upscalers() modelloader.load_upscalers()
modules.sd_vae.refresh_vae_list() modules.sd_vae.refresh_vae_list()
@ -169,23 +170,22 @@ def webui():
modules.script_callbacks.app_started_callback(shared.demo, app) modules.script_callbacks.app_started_callback(shared.demo, app)
wait_on_server(shared.demo) wait_on_server(shared.demo)
print('Restarting UI...')
sd_samplers.set_samplers() sd_samplers.set_samplers()
print('Reloading extensions')
extensions.list_extensions() extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir) localization.list_localizations(cmd_opts.localizations_dir)
print('Reloading custom scripts') modelloader.forbid_loaded_nonbuiltin_upscalers()
modules.scripts.reload_scripts() modules.scripts.reload_scripts()
modelloader.load_upscalers() modelloader.load_upscalers()
print('Reloading modules: modules.ui') for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
importlib.reload(modules.ui) importlib.reload(module)
print('Refreshing Model List')
modules.sd_models.list_models() modules.sd_models.list_models()
print('Restarting Gradio')
if __name__ == "__main__": if __name__ == "__main__":