Merge pull request #3197 from AUTOMATIC1111/training-help-text

Training UI Changes
This commit is contained in:
AUTOMATIC1111 2022-10-21 09:58:16 +03:00 committed by GitHub
commit e4877722e3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 50 additions and 23 deletions

View file

@ -396,7 +396,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
Loss: {mean_loss:.7f}<br/> Loss: {mean_loss:.7f}<br/>
Step: {hypernetwork.step}<br/> Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/> Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/> Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/> Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>
""" """

View file

@ -10,8 +10,9 @@ from modules import sd_hijack, shared, devices
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False, activation_func=None): def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, add_layer_norm=False, activation_func=None):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
if not overwrite_old:
assert not os.path.exists(fn), f"file {fn} already exists" assert not os.path.exists(fn), f"file {fn} already exists"
if type(layer_structure) == str: if type(layer_structure) == str:

View file

@ -11,7 +11,7 @@ if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru import modules.deepbooru as deepbooru
def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False):
try: try:
if process_caption: if process_caption:
shared.interrogator.load() shared.interrogator.load()
@ -21,7 +21,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts) deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru) preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru)
finally: finally:
@ -33,7 +33,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False):
width = process_width width = process_width
height = process_height height = process_height
src = os.path.abspath(process_src) src = os.path.abspath(process_src)
@ -48,7 +48,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
shared.state.textinfo = "Preprocessing..." shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files) shared.state.job_count = len(files)
def save_pic_with_caption(image, index): def save_pic_with_caption(image, index, existing_caption=None):
caption = "" caption = ""
if process_caption: if process_caption:
@ -66,17 +66,26 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
basename = f"{index:05}-{subindex[0]}-{filename_part}" basename = f"{index:05}-{subindex[0]}-{filename_part}"
image.save(os.path.join(dst, f"{basename}.png")) image.save(os.path.join(dst, f"{basename}.png"))
if preprocess_txt_action == 'prepend' and existing_caption:
caption = existing_caption + ' ' + caption
elif preprocess_txt_action == 'append' and existing_caption:
caption = caption + ' ' + existing_caption
elif preprocess_txt_action == 'copy' and existing_caption:
caption = existing_caption
caption = caption.strip()
if len(caption) > 0: if len(caption) > 0:
with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file: with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
file.write(caption) file.write(caption)
subindex[0] += 1 subindex[0] += 1
def save_pic(image, index): def save_pic(image, index, existing_caption=None):
save_pic_with_caption(image, index) save_pic_with_caption(image, index, existing_caption=existing_caption)
if process_flip: if process_flip:
save_pic_with_caption(ImageOps.mirror(image), index) save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
for index, imagefile in enumerate(tqdm.tqdm(files)): for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0] subindex = [0]
@ -86,6 +95,13 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
except Exception: except Exception:
continue continue
existing_caption = None
try:
existing_caption = open(os.path.splitext(filename)[0] + '.txt', 'r').read()
except Exception as e:
print(e)
if shared.state.interrupted: if shared.state.interrupted:
break break
@ -97,20 +113,20 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
img = img.resize((width, height * img.height // img.width)) img = img.resize((width, height * img.height // img.width))
top = img.crop((0, 0, width, height)) top = img.crop((0, 0, width, height))
save_pic(top, index) save_pic(top, index, existing_caption=existing_caption)
bot = img.crop((0, img.height - height, width, img.height)) bot = img.crop((0, img.height - height, width, img.height))
save_pic(bot, index) save_pic(bot, index, existing_caption=existing_caption)
elif process_split and is_wide: elif process_split and is_wide:
img = img.resize((width * img.width // img.height, height)) img = img.resize((width * img.width // img.height, height))
left = img.crop((0, 0, width, height)) left = img.crop((0, 0, width, height))
save_pic(left, index) save_pic(left, index, existing_caption=existing_caption)
right = img.crop((img.width - width, 0, img.width, height)) right = img.crop((img.width - width, 0, img.width, height))
save_pic(right, index) save_pic(right, index, existing_caption=existing_caption)
else: else:
img = images.resize_image(1, img, width, height) img = images.resize_image(1, img, width, height)
save_pic(img, index) save_pic(img, index, existing_caption=existing_caption)
shared.state.nextjob() shared.state.nextjob()

View file

@ -153,7 +153,7 @@ class EmbeddingDatabase:
return None, None return None, None
def create_embedding(name, num_vectors_per_token, init_text='*'): def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
@ -165,6 +165,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt") fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
if not overwrite_old:
assert not os.path.exists(fn), f"file {fn} already exists" assert not os.path.exists(fn), f"file {fn} already exists"
embedding = Embedding(vec, name) embedding = Embedding(vec, name)

View file

@ -7,8 +7,8 @@ import modules.textual_inversion.preprocess
from modules import sd_hijack, shared from modules import sd_hijack, shared
def create_embedding(name, initialization_text, nvpt): def create_embedding(name, initialization_text, nvpt, overwrite_old):
filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text) filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()

View file

@ -1211,6 +1211,7 @@ def create_ui(wrap_gradio_gpu_call):
new_embedding_name = gr.Textbox(label="Name") new_embedding_name = gr.Textbox(label="Name")
initialization_text = gr.Textbox(label="Initialization text", value="*") initialization_text = gr.Textbox(label="Initialization text", value="*")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding")
with gr.Row(): with gr.Row():
with gr.Column(scale=3): with gr.Column(scale=3):
@ -1224,6 +1225,7 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"]) new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
with gr.Row(): with gr.Row():
@ -1238,6 +1240,7 @@ def create_ui(wrap_gradio_gpu_call):
process_dst = gr.Textbox(label='Destination directory') process_dst = gr.Textbox(label='Destination directory')
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
with gr.Row(): with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies') process_flip = gr.Checkbox(label='Create flipped copies')
@ -1253,14 +1256,17 @@ def create_ui(wrap_gradio_gpu_call):
run_preprocess = gr.Button(value="Preprocess", variant='primary') run_preprocess = gr.Button(value="Preprocess", variant='primary')
with gr.Tab(label="Train"): with gr.Tab(label="Train"):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>") gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
with gr.Row(): with gr.Row():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
with gr.Row(): with gr.Row():
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") with gr.Row():
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
batch_size = gr.Number(label='Batch size', value=1, precision=0) batch_size = gr.Number(label='Batch size', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
@ -1294,6 +1300,7 @@ def create_ui(wrap_gradio_gpu_call):
new_embedding_name, new_embedding_name,
initialization_text, initialization_text,
nvpt, nvpt,
overwrite_old_embedding,
], ],
outputs=[ outputs=[
train_embedding_name, train_embedding_name,
@ -1307,6 +1314,7 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[ inputs=[
new_hypernetwork_name, new_hypernetwork_name,
new_hypernetwork_sizes, new_hypernetwork_sizes,
overwrite_old_hypernetwork,
new_hypernetwork_layer_structure, new_hypernetwork_layer_structure,
new_hypernetwork_add_layer_norm, new_hypernetwork_add_layer_norm,
new_hypernetwork_activation_func, new_hypernetwork_activation_func,
@ -1326,6 +1334,7 @@ def create_ui(wrap_gradio_gpu_call):
process_dst, process_dst,
process_width, process_width,
process_height, process_height,
preprocess_txt_action,
process_flip, process_flip,
process_split, process_split,
process_caption, process_caption,
@ -1342,7 +1351,7 @@ def create_ui(wrap_gradio_gpu_call):
_js="start_training_textual_inversion", _js="start_training_textual_inversion",
inputs=[ inputs=[
train_embedding_name, train_embedding_name,
learn_rate, embedding_learn_rate,
batch_size, batch_size,
dataset_directory, dataset_directory,
log_directory, log_directory,
@ -1367,7 +1376,7 @@ def create_ui(wrap_gradio_gpu_call):
_js="start_training_textual_inversion", _js="start_training_textual_inversion",
inputs=[ inputs=[
train_hypernetwork_name, train_hypernetwork_name,
learn_rate, hypernetwork_learn_rate,
batch_size, batch_size,
dataset_directory, dataset_directory,
log_directory, log_directory,