make a dropdown for prompt template selection
This commit is contained in:
parent
43bb5190fc
commit
1fbb6f9ebe
5 changed files with 45 additions and 12 deletions
|
@ -24,6 +24,7 @@ from statistics import stdev, mean
|
||||||
|
|
||||||
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
||||||
|
|
||||||
|
|
||||||
class HypernetworkModule(torch.nn.Module):
|
class HypernetworkModule(torch.nn.Module):
|
||||||
multiplier = 1.0
|
multiplier = 1.0
|
||||||
activation_dict = {
|
activation_dict = {
|
||||||
|
@ -403,13 +404,15 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
||||||
shared.reload_hypernetworks()
|
shared.reload_hypernetworks()
|
||||||
|
|
||||||
|
|
||||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||||
from modules import images
|
from modules import images
|
||||||
|
|
||||||
save_hypernetwork_every = save_hypernetwork_every or 0
|
save_hypernetwork_every = save_hypernetwork_every or 0
|
||||||
create_image_every = create_image_every or 0
|
create_image_every = create_image_every or 0
|
||||||
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
|
template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
|
||||||
|
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
|
||||||
|
template_file = template_file.path
|
||||||
|
|
||||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||||
shared.loaded_hypernetwork = Hypernetwork()
|
shared.loaded_hypernetwork = Hypernetwork()
|
||||||
|
|
|
@ -33,6 +33,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th
|
||||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
||||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||||
|
parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
|
||||||
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
||||||
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||||
|
|
|
@ -2,6 +2,7 @@ import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import inspect
|
import inspect
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
|
@ -15,12 +16,26 @@ from modules import shared, devices, sd_hijack, processing, sd_models, images, s
|
||||||
import modules.textual_inversion.dataset
|
import modules.textual_inversion.dataset
|
||||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||||
|
|
||||||
from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
|
from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
|
||||||
insert_image_data_embed, extract_image_data_embed,
|
|
||||||
caption_image_overlay)
|
|
||||||
from modules.textual_inversion.logging import save_settings_to_file
|
from modules.textual_inversion.logging import save_settings_to_file
|
||||||
|
|
||||||
|
|
||||||
|
TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])
|
||||||
|
textual_inversion_templates = {}
|
||||||
|
|
||||||
|
|
||||||
|
def list_textual_inversion_templates():
|
||||||
|
textual_inversion_templates.clear()
|
||||||
|
|
||||||
|
for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
|
||||||
|
for fn in fns:
|
||||||
|
path = os.path.join(root, fn)
|
||||||
|
|
||||||
|
textual_inversion_templates[fn] = TextualInversionTemplate(fn, path)
|
||||||
|
|
||||||
|
return textual_inversion_templates
|
||||||
|
|
||||||
|
|
||||||
class Embedding:
|
class Embedding:
|
||||||
def __init__(self, vec, name, step=None):
|
def __init__(self, vec, name, step=None):
|
||||||
self.vec = vec
|
self.vec = vec
|
||||||
|
@ -274,7 +289,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
|
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
|
||||||
assert model_name, f"{name} not selected"
|
assert model_name, f"{name} not selected"
|
||||||
assert learn_rate, "Learning rate is empty or 0"
|
assert learn_rate, "Learning rate is empty or 0"
|
||||||
assert isinstance(batch_size, int), "Batch size must be integer"
|
assert isinstance(batch_size, int), "Batch size must be integer"
|
||||||
|
@ -284,8 +299,9 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
|
||||||
assert data_root, "Dataset directory is empty"
|
assert data_root, "Dataset directory is empty"
|
||||||
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
||||||
assert os.listdir(data_root), "Dataset directory is empty"
|
assert os.listdir(data_root), "Dataset directory is empty"
|
||||||
assert template_file, "Prompt template file is empty"
|
assert template_filename, "Prompt template file not selected"
|
||||||
assert os.path.isfile(template_file), "Prompt template file doesn't exist"
|
assert template_file, f"Prompt template file {template_filename} not found"
|
||||||
|
assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist"
|
||||||
assert steps, "Max steps is empty or 0"
|
assert steps, "Max steps is empty or 0"
|
||||||
assert isinstance(steps, int), "Max steps must be integer"
|
assert isinstance(steps, int), "Max steps must be integer"
|
||||||
assert steps > 0, "Max steps must be positive"
|
assert steps > 0, "Max steps must be positive"
|
||||||
|
@ -296,10 +312,13 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
|
||||||
if save_model_every or create_image_every:
|
if save_model_every or create_image_every:
|
||||||
assert log_directory, "Log directory is empty"
|
assert log_directory, "Log directory is empty"
|
||||||
|
|
||||||
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
|
||||||
|
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
save_embedding_every = save_embedding_every or 0
|
save_embedding_every = save_embedding_every or 0
|
||||||
create_image_every = create_image_every or 0
|
create_image_every = create_image_every or 0
|
||||||
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
|
template_file = textual_inversion_templates.get(template_filename, None)
|
||||||
|
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
|
||||||
|
template_file = template_file.path
|
||||||
|
|
||||||
shared.state.job = "train-embedding"
|
shared.state.job = "train-embedding"
|
||||||
shared.state.textinfo = "Initializing textual inversion training..."
|
shared.state.textinfo = "Initializing textual inversion training..."
|
||||||
|
|
|
@ -37,7 +37,7 @@ from modules import prompt_parser
|
||||||
from modules.images import save_image
|
from modules.images import save_image
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.sd_samplers import samplers, samplers_for_img2img
|
from modules.sd_samplers import samplers, samplers_for_img2img
|
||||||
import modules.textual_inversion.ui
|
from modules.textual_inversion import textual_inversion
|
||||||
import modules.hypernetworks.ui
|
import modules.hypernetworks.ui
|
||||||
from modules.generation_parameters_copypaste import image_from_url_text
|
from modules.generation_parameters_copypaste import image_from_url_text
|
||||||
|
|
||||||
|
@ -1322,6 +1322,9 @@ def create_ui():
|
||||||
outputs=[process_focal_crop_row],
|
outputs=[process_focal_crop_row],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_textual_inversion_template_names():
|
||||||
|
return sorted([x for x in textual_inversion.textual_inversion_templates])
|
||||||
|
|
||||||
with gr.Tab(label="Train"):
|
with gr.Tab(label="Train"):
|
||||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
|
||||||
with FormRow():
|
with FormRow():
|
||||||
|
@ -1345,7 +1348,11 @@ def create_ui():
|
||||||
|
|
||||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
|
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
|
||||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
|
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
|
||||||
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file")
|
|
||||||
|
with FormRow():
|
||||||
|
template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names())
|
||||||
|
create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file")
|
||||||
|
|
||||||
training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
|
training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
|
||||||
training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
|
training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
|
||||||
varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
|
varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
|
||||||
|
|
3
webui.py
3
webui.py
|
@ -33,6 +33,7 @@ import modules.sd_models
|
||||||
import modules.sd_vae
|
import modules.sd_vae
|
||||||
import modules.txt2img
|
import modules.txt2img
|
||||||
import modules.script_callbacks
|
import modules.script_callbacks
|
||||||
|
import modules.textual_inversion.textual_inversion
|
||||||
|
|
||||||
import modules.ui
|
import modules.ui
|
||||||
from modules import modelloader
|
from modules import modelloader
|
||||||
|
@ -67,6 +68,8 @@ def initialize():
|
||||||
|
|
||||||
modules.sd_vae.refresh_vae_list()
|
modules.sd_vae.refresh_vae_list()
|
||||||
|
|
||||||
|
modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
modules.sd_models.load_model()
|
modules.sd_models.load_model()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
Loading…
Reference in a new issue