feat: add app started callback
This commit is contained in:
parent
17a2076f72
commit
423f222283
2 changed files with 18 additions and 0 deletions
|
@ -3,6 +3,8 @@ import traceback
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from gradio import Blocks
|
||||||
|
|
||||||
def report_exception(c, job):
|
def report_exception(c, job):
|
||||||
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
|
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
|
||||||
|
@ -25,6 +27,7 @@ class ImageSaveParams:
|
||||||
|
|
||||||
|
|
||||||
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
|
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
|
||||||
|
callbacks_app_started = []
|
||||||
callbacks_model_loaded = []
|
callbacks_model_loaded = []
|
||||||
callbacks_ui_tabs = []
|
callbacks_ui_tabs = []
|
||||||
callbacks_ui_settings = []
|
callbacks_ui_settings = []
|
||||||
|
@ -40,6 +43,14 @@ def clear_callbacks():
|
||||||
callbacks_image_saved.clear()
|
callbacks_image_saved.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def app_started_callback(demo: Blocks, app: FastAPI):
|
||||||
|
for c in callbacks_app_started:
|
||||||
|
try:
|
||||||
|
c.callback(demo, app)
|
||||||
|
except Exception:
|
||||||
|
report_exception(c, 'app_started_callback')
|
||||||
|
|
||||||
|
|
||||||
def model_loaded_callback(sd_model):
|
def model_loaded_callback(sd_model):
|
||||||
for c in callbacks_model_loaded:
|
for c in callbacks_model_loaded:
|
||||||
try:
|
try:
|
||||||
|
@ -91,6 +102,10 @@ def add_callback(callbacks, fun):
|
||||||
callbacks.append(ScriptCallback(filename, fun))
|
callbacks.append(ScriptCallback(filename, fun))
|
||||||
|
|
||||||
|
|
||||||
|
def on_app_started(callback):
|
||||||
|
add_callback(callbacks_app_started, callback)
|
||||||
|
|
||||||
|
|
||||||
def on_model_loaded(callback):
|
def on_model_loaded(callback):
|
||||||
"""register a function to be called when the stable diffusion model is created; the model is
|
"""register a function to be called when the stable diffusion model is created; the model is
|
||||||
passed as an argument"""
|
passed as an argument"""
|
||||||
|
|
3
webui.py
3
webui.py
|
@ -23,6 +23,7 @@ import modules.sd_hijack
|
||||||
import modules.sd_models
|
import modules.sd_models
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import modules.txt2img
|
import modules.txt2img
|
||||||
|
import modules.script_callbacks
|
||||||
|
|
||||||
import modules.ui
|
import modules.ui
|
||||||
from modules import devices
|
from modules import devices
|
||||||
|
@ -135,6 +136,8 @@ def webui():
|
||||||
if (launch_api):
|
if (launch_api):
|
||||||
create_api(app)
|
create_api(app)
|
||||||
|
|
||||||
|
modules.script_callbacks.app_started_callback(demo, app)
|
||||||
|
|
||||||
wait_on_server(demo)
|
wait_on_server(demo)
|
||||||
|
|
||||||
sd_samplers.set_samplers()
|
sd_samplers.set_samplers()
|
||||||
|
|
Loading…
Reference in a new issue