add callback for creating a tab in train UI
This commit is contained in:
parent
8011be33c3
commit
1610b32584
2 changed files with 29 additions and 2 deletions
|
@ -7,6 +7,7 @@ from typing import Optional
|
|||
from fastapi import FastAPI
|
||||
from gradio import Blocks
|
||||
|
||||
|
||||
def report_exception(c, job):
|
||||
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
@ -45,15 +46,21 @@ class CFGDenoiserParams:
|
|||
"""Total number of sampling steps planned"""
|
||||
|
||||
|
||||
class UiTrainTabParams:
|
||||
def __init__(self, txt2img_preview_params):
|
||||
self.txt2img_preview_params = txt2img_preview_params
|
||||
|
||||
|
||||
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
|
||||
callback_map = dict(
|
||||
callbacks_app_started=[],
|
||||
callbacks_model_loaded=[],
|
||||
callbacks_ui_tabs=[],
|
||||
callbacks_ui_train_tabs=[],
|
||||
callbacks_ui_settings=[],
|
||||
callbacks_before_image_saved=[],
|
||||
callbacks_image_saved=[],
|
||||
callbacks_cfg_denoiser=[]
|
||||
callbacks_cfg_denoiser=[],
|
||||
)
|
||||
|
||||
|
||||
|
@ -61,6 +68,7 @@ def clear_callbacks():
|
|||
for callback_list in callback_map.values():
|
||||
callback_list.clear()
|
||||
|
||||
|
||||
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
|
||||
for c in callback_map['callbacks_app_started']:
|
||||
try:
|
||||
|
@ -79,7 +87,7 @@ def model_loaded_callback(sd_model):
|
|||
|
||||
def ui_tabs_callback():
|
||||
res = []
|
||||
|
||||
|
||||
for c in callback_map['callbacks_ui_tabs']:
|
||||
try:
|
||||
res += c.callback() or []
|
||||
|
@ -89,6 +97,14 @@ def ui_tabs_callback():
|
|||
return res
|
||||
|
||||
|
||||
def ui_train_tabs_callback(params: UiTrainTabParams):
|
||||
for c in callback_map['callbacks_ui_train_tabs']:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'callbacks_ui_train_tabs')
|
||||
|
||||
|
||||
def ui_settings_callback():
|
||||
for c in callback_map['callbacks_ui_settings']:
|
||||
try:
|
||||
|
@ -169,6 +185,13 @@ def on_ui_tabs(callback):
|
|||
add_callback(callback_map['callbacks_ui_tabs'], callback)
|
||||
|
||||
|
||||
def on_ui_train_tabs(callback):
|
||||
"""register a function to be called when the UI is creating new tabs for the train tab.
|
||||
Create your new tabs with gr.Tab.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_ui_train_tabs'], callback)
|
||||
|
||||
|
||||
def on_ui_settings(callback):
|
||||
"""register a function to be called before UI settings are populated; add your settings
|
||||
by using shared.opts.add_option(shared.OptionInfo(...)) """
|
||||
|
|
|
@ -1270,6 +1270,10 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
|
||||
train_embedding = gr.Button(value="Train Embedding", variant='primary')
|
||||
|
||||
params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
|
||||
|
||||
script_callbacks.ui_train_tabs_callback(params)
|
||||
|
||||
with gr.Column():
|
||||
progressbar = gr.HTML(elem_id="ti_progressbar")
|
||||
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
|
||||
|
|
Loading…
Reference in a new issue