train tab visual updates
allow setting train tab values from ui-config.json
This commit is contained in:
parent
9092e1ca77
commit
24d4a0841d
2 changed files with 22 additions and 15 deletions
|
@ -1281,42 +1281,48 @@ def create_ui():
|
|||
|
||||
with gr.Tab(label="Train"):
|
||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
|
||||
with gr.Row():
|
||||
with FormRow():
|
||||
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
|
||||
with gr.Row():
|
||||
|
||||
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
|
||||
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
|
||||
with gr.Row():
|
||||
|
||||
with FormRow():
|
||||
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
|
||||
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
|
||||
|
||||
with gr.Row():
|
||||
with FormRow():
|
||||
clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
|
||||
clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
|
||||
|
||||
batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
|
||||
gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
|
||||
with FormRow():
|
||||
batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
|
||||
gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
|
||||
|
||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
|
||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
|
||||
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file")
|
||||
training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
|
||||
training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
|
||||
steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
|
||||
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
|
||||
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
|
||||
|
||||
with FormRow():
|
||||
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
|
||||
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
|
||||
|
||||
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
|
||||
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
|
||||
with gr.Row():
|
||||
shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
|
||||
tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
|
||||
with gr.Row():
|
||||
latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
|
||||
|
||||
shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
|
||||
tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
|
||||
|
||||
latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
|
||||
|
||||
with gr.Row():
|
||||
train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
|
||||
interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training")
|
||||
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork")
|
||||
train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
|
||||
|
||||
params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
|
||||
|
||||
|
@ -1803,6 +1809,7 @@ def create_ui():
|
|||
visit(img2img_interface, loadsave, "img2img")
|
||||
visit(extras_interface, loadsave, "extras")
|
||||
visit(modelmerger_interface, loadsave, "modelmerger")
|
||||
visit(train_interface, loadsave, "train")
|
||||
|
||||
if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
|
||||
with open(ui_config_file, "w", encoding="utf8") as file:
|
||||
|
|
|
@ -611,7 +611,7 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h
|
|||
padding-top: 0.9em;
|
||||
}
|
||||
|
||||
#img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form{
|
||||
#img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form, #train_tabs div.gr-form .gr-form{
|
||||
border: none;
|
||||
padding-bottom: 0.5em;
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue