feat: add app started callback

This commit is contained in:
Maiko Tan 2022-10-30 22:46:43 +08:00
parent 17a2076f72
commit 423f222283
No known key found for this signature in database
GPG key ID: 0F3B49C721E5F453
2 changed files with 18 additions and 0 deletions

View file

@ -3,6 +3,8 @@ import traceback
from collections import namedtuple from collections import namedtuple
import inspect import inspect
from fastapi import FastAPI
from gradio import Blocks
def report_exception(c, job): def report_exception(c, job):
print(f"Error executing callback {job} for {c.script}", file=sys.stderr) print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
@ -25,6 +27,7 @@ class ImageSaveParams:
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callbacks_app_started = []
callbacks_model_loaded = [] callbacks_model_loaded = []
callbacks_ui_tabs = [] callbacks_ui_tabs = []
callbacks_ui_settings = [] callbacks_ui_settings = []
@ -40,6 +43,14 @@ def clear_callbacks():
callbacks_image_saved.clear() callbacks_image_saved.clear()
def app_started_callback(demo: Blocks, app: FastAPI):
for c in callbacks_app_started:
try:
c.callback(demo, app)
except Exception:
report_exception(c, 'app_started_callback')
def model_loaded_callback(sd_model): def model_loaded_callback(sd_model):
for c in callbacks_model_loaded: for c in callbacks_model_loaded:
try: try:
@ -91,6 +102,10 @@ def add_callback(callbacks, fun):
callbacks.append(ScriptCallback(filename, fun)) callbacks.append(ScriptCallback(filename, fun))
def on_app_started(callback):
add_callback(callbacks_app_started, callback)
def on_model_loaded(callback): def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is """register a function to be called when the stable diffusion model is created; the model is
passed as an argument""" passed as an argument"""

View file

@ -23,6 +23,7 @@ import modules.sd_hijack
import modules.sd_models import modules.sd_models
import modules.shared as shared import modules.shared as shared
import modules.txt2img import modules.txt2img
import modules.script_callbacks
import modules.ui import modules.ui
from modules import devices from modules import devices
@ -135,6 +136,8 @@ def webui():
if (launch_api): if (launch_api):
create_api(app) create_api(app)
modules.script_callbacks.app_started_callback(demo, app)
wait_on_server(demo) wait_on_server(demo)
sd_samplers.set_samplers() sd_samplers.set_samplers()