Merge pull request #1454 from AUTOMATIC1111/1404-script-reload-without-restart

Gradio+custom script+js+css reload without model reloading
This commit is contained in:
AUTOMATIC1111 2022-10-02 21:33:32 +03:00 committed by GitHub
commit ad0503c1b5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 109 additions and 19 deletions

View file

@ -218,6 +218,7 @@ function update_token_counter(button_id) {
clearTimeout(token_timeout); clearTimeout(token_timeout);
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
} }
function submit_prompt(event, generate_button_id) { function submit_prompt(event, generate_button_id) {
if (event.altKey && event.keyCode === 13) { if (event.altKey && event.keyCode === 13) {
event.preventDefault(); event.preventDefault();
@ -225,3 +226,8 @@ function submit_prompt(event, generate_button_id) {
return; return;
} }
} }
function restart_reload(){
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
setTimeout(function(){location.reload()},2000)
}

View file

@ -162,6 +162,40 @@ class ScriptRunner:
return processed return processed
def reload_sources(self):
for si, script in list(enumerate(self.scripts)):
with open(script.filename, "r", encoding="utf8") as file:
args_from = script.args_from
args_to = script.args_to
filename = script.filename
text = file.read()
from types import ModuleType
compiled = compile(text, filename, 'exec')
module = ModuleType(script.filename)
exec(compiled, module.__dict__)
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
self.scripts[si] = script_class()
self.scripts[si].filename = filename
self.scripts[si].args_from = args_from
self.scripts[si].args_to = args_to
scripts_txt2img = ScriptRunner() scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner() scripts_img2img = ScriptRunner()
def reload_script_body_only():
scripts_txt2img.reload_sources()
scripts_img2img.reload_sources()
def reload_scripts(basedir):
global scripts_txt2img, scripts_img2img
scripts_data.clear()
load_scripts(basedir)
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()

View file

@ -1145,6 +1145,31 @@ def create_ui(wrap_gradio_gpu_call):
_js='function(){}' _js='function(){}'
) )
with gr.Row():
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
def reload_scripts():
modules.scripts.reload_script_body_only()
reload_script_bodies.click(
fn=reload_scripts,
inputs=[],
outputs=[],
_js='function(){}'
)
def request_restart():
settings_interface.gradio_ref.do_restart = True
restart_gradio.click(
fn=request_restart,
inputs=[],
outputs=[],
_js='function(){restart_reload()}'
)
if column is not None: if column is not None:
column.__exit__() column.__exit__()
@ -1171,6 +1196,8 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
settings_interface.gradio_ref = demo
with gr.Tabs() as tabs: with gr.Tabs() as tabs:
for interface, label, ifid in interfaces: for interface, label, ifid in interfaces:
with gr.TabItem(label, id=ifid): with gr.TabItem(label, id=ifid):
@ -1350,12 +1377,12 @@ for filename in sorted(os.listdir(jsdir)):
javascript += f"\n<script>{jsfile.read()}</script>" javascript += f"\n<script>{jsfile.read()}</script>"
def template_response(*args, **kwargs): if 'gradio_routes_templates_response' not in globals():
def template_response(*args, **kwargs):
res = gradio_routes_templates_response(*args, **kwargs) res = gradio_routes_templates_response(*args, **kwargs)
res.body = res.body.replace(b'</head>', f'{javascript}</head>'.encode("utf8")) res.body = res.body.replace(b'</head>', f'{javascript}</head>'.encode("utf8"))
res.init_headers() res.init_headers()
return res return res
gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
gradio_routes_templates_response = gradio.routes.templates.TemplateResponse gradio.routes.templates.TemplateResponse = template_response
gradio.routes.templates.TemplateResponse = template_response

View file

@ -1,4 +1,9 @@
import os import os
import threading
import time
import importlib
from modules import devices
from modules.paths import script_path
import signal import signal
import threading import threading
@ -82,6 +87,8 @@ def webui():
signal.signal(signal.SIGINT, sigint_handler) signal.signal(signal.SIGINT, sigint_handler)
while 1:
demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
demo.launch( demo.launch(
@ -91,8 +98,24 @@ def webui():
debug=cmd_opts.gradio_debug, debug=cmd_opts.gradio_debug,
auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None, auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None,
inbrowser=cmd_opts.autolaunch, inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True
) )
while 1:
time.sleep(0.5)
if getattr(demo,'do_restart',False):
time.sleep(0.5)
demo.close()
time.sleep(0.5)
break
print('Reloading Custom Scripts')
modules.scripts.reload_scripts(os.path.join(script_path, "scripts"))
print('Reloading modules: modules.ui')
importlib.reload(modules.ui)
print('Restarting Gradio')
if __name__ == "__main__": if __name__ == "__main__":
webui() webui()