add callback for creating a tab in train UI
This commit is contained in:
parent
8011be33c3
commit
1610b32584
2 changed files with 29 additions and 2 deletions
|
@ -7,6 +7,7 @@ from typing import Optional
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from gradio import Blocks
|
from gradio import Blocks
|
||||||
|
|
||||||
|
|
||||||
def report_exception(c, job):
|
def report_exception(c, job):
|
||||||
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
|
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
@ -45,15 +46,21 @@ class CFGDenoiserParams:
|
||||||
"""Total number of sampling steps planned"""
|
"""Total number of sampling steps planned"""
|
||||||
|
|
||||||
|
|
||||||
|
class UiTrainTabParams:
|
||||||
|
def __init__(self, txt2img_preview_params):
|
||||||
|
self.txt2img_preview_params = txt2img_preview_params
|
||||||
|
|
||||||
|
|
||||||
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
|
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
|
||||||
callback_map = dict(
|
callback_map = dict(
|
||||||
callbacks_app_started=[],
|
callbacks_app_started=[],
|
||||||
callbacks_model_loaded=[],
|
callbacks_model_loaded=[],
|
||||||
callbacks_ui_tabs=[],
|
callbacks_ui_tabs=[],
|
||||||
|
callbacks_ui_train_tabs=[],
|
||||||
callbacks_ui_settings=[],
|
callbacks_ui_settings=[],
|
||||||
callbacks_before_image_saved=[],
|
callbacks_before_image_saved=[],
|
||||||
callbacks_image_saved=[],
|
callbacks_image_saved=[],
|
||||||
callbacks_cfg_denoiser=[]
|
callbacks_cfg_denoiser=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -61,6 +68,7 @@ def clear_callbacks():
|
||||||
for callback_list in callback_map.values():
|
for callback_list in callback_map.values():
|
||||||
callback_list.clear()
|
callback_list.clear()
|
||||||
|
|
||||||
|
|
||||||
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
|
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
|
||||||
for c in callback_map['callbacks_app_started']:
|
for c in callback_map['callbacks_app_started']:
|
||||||
try:
|
try:
|
||||||
|
@ -89,6 +97,14 @@ def ui_tabs_callback():
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def ui_train_tabs_callback(params: UiTrainTabParams):
|
||||||
|
for c in callback_map['callbacks_ui_train_tabs']:
|
||||||
|
try:
|
||||||
|
c.callback(params)
|
||||||
|
except Exception:
|
||||||
|
report_exception(c, 'callbacks_ui_train_tabs')
|
||||||
|
|
||||||
|
|
||||||
def ui_settings_callback():
|
def ui_settings_callback():
|
||||||
for c in callback_map['callbacks_ui_settings']:
|
for c in callback_map['callbacks_ui_settings']:
|
||||||
try:
|
try:
|
||||||
|
@ -169,6 +185,13 @@ def on_ui_tabs(callback):
|
||||||
add_callback(callback_map['callbacks_ui_tabs'], callback)
|
add_callback(callback_map['callbacks_ui_tabs'], callback)
|
||||||
|
|
||||||
|
|
||||||
|
def on_ui_train_tabs(callback):
|
||||||
|
"""register a function to be called when the UI is creating new tabs for the train tab.
|
||||||
|
Create your new tabs with gr.Tab.
|
||||||
|
"""
|
||||||
|
add_callback(callback_map['callbacks_ui_train_tabs'], callback)
|
||||||
|
|
||||||
|
|
||||||
def on_ui_settings(callback):
|
def on_ui_settings(callback):
|
||||||
"""register a function to be called before UI settings are populated; add your settings
|
"""register a function to be called before UI settings are populated; add your settings
|
||||||
by using shared.opts.add_option(shared.OptionInfo(...)) """
|
by using shared.opts.add_option(shared.OptionInfo(...)) """
|
||||||
|
|
|
@ -1270,6 +1270,10 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
|
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
|
||||||
train_embedding = gr.Button(value="Train Embedding", variant='primary')
|
train_embedding = gr.Button(value="Train Embedding", variant='primary')
|
||||||
|
|
||||||
|
params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
|
||||||
|
|
||||||
|
script_callbacks.ui_train_tabs_callback(params)
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
progressbar = gr.HTML(elem_id="ti_progressbar")
|
progressbar = gr.HTML(elem_id="ti_progressbar")
|
||||||
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
|
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
|
||||||
|
|
Loading…
Reference in a new issue