Merge pull request #3197 from AUTOMATIC1111/training-help-text
Training UI Changes
This commit is contained in:
commit
e4877722e3
6 changed files with 50 additions and 23 deletions
|
@ -396,7 +396,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
Loss: {mean_loss:.7f}<br/>
|
Loss: {mean_loss:.7f}<br/>
|
||||||
Step: {hypernetwork.step}<br/>
|
Step: {hypernetwork.step}<br/>
|
||||||
Last prompt: {html.escape(entries[0].cond_text)}<br/>
|
Last prompt: {html.escape(entries[0].cond_text)}<br/>
|
||||||
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
|
||||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
</p>
|
</p>
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -10,9 +10,10 @@ from modules import sd_hijack, shared, devices
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
|
|
||||||
|
|
||||||
def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False, activation_func=None):
|
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, add_layer_norm=False, activation_func=None):
|
||||||
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
||||||
assert not os.path.exists(fn), f"file {fn} already exists"
|
if not overwrite_old:
|
||||||
|
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||||
|
|
||||||
if type(layer_structure) == str:
|
if type(layer_structure) == str:
|
||||||
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
|
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
|
||||||
|
|
|
@ -11,7 +11,7 @@ if cmd_opts.deepdanbooru:
|
||||||
import modules.deepbooru as deepbooru
|
import modules.deepbooru as deepbooru
|
||||||
|
|
||||||
|
|
||||||
def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
|
def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False):
|
||||||
try:
|
try:
|
||||||
if process_caption:
|
if process_caption:
|
||||||
shared.interrogator.load()
|
shared.interrogator.load()
|
||||||
|
@ -21,7 +21,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
|
||||||
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
|
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
|
||||||
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
|
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
|
||||||
|
|
||||||
preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
|
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
|
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False):
|
||||||
width = process_width
|
width = process_width
|
||||||
height = process_height
|
height = process_height
|
||||||
src = os.path.abspath(process_src)
|
src = os.path.abspath(process_src)
|
||||||
|
@ -48,7 +48,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
|
||||||
shared.state.textinfo = "Preprocessing..."
|
shared.state.textinfo = "Preprocessing..."
|
||||||
shared.state.job_count = len(files)
|
shared.state.job_count = len(files)
|
||||||
|
|
||||||
def save_pic_with_caption(image, index):
|
def save_pic_with_caption(image, index, existing_caption=None):
|
||||||
caption = ""
|
caption = ""
|
||||||
|
|
||||||
if process_caption:
|
if process_caption:
|
||||||
|
@ -66,17 +66,26 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
|
||||||
basename = f"{index:05}-{subindex[0]}-{filename_part}"
|
basename = f"{index:05}-{subindex[0]}-{filename_part}"
|
||||||
image.save(os.path.join(dst, f"{basename}.png"))
|
image.save(os.path.join(dst, f"{basename}.png"))
|
||||||
|
|
||||||
|
if preprocess_txt_action == 'prepend' and existing_caption:
|
||||||
|
caption = existing_caption + ' ' + caption
|
||||||
|
elif preprocess_txt_action == 'append' and existing_caption:
|
||||||
|
caption = caption + ' ' + existing_caption
|
||||||
|
elif preprocess_txt_action == 'copy' and existing_caption:
|
||||||
|
caption = existing_caption
|
||||||
|
|
||||||
|
caption = caption.strip()
|
||||||
|
|
||||||
if len(caption) > 0:
|
if len(caption) > 0:
|
||||||
with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
|
with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
|
||||||
file.write(caption)
|
file.write(caption)
|
||||||
|
|
||||||
subindex[0] += 1
|
subindex[0] += 1
|
||||||
|
|
||||||
def save_pic(image, index):
|
def save_pic(image, index, existing_caption=None):
|
||||||
save_pic_with_caption(image, index)
|
save_pic_with_caption(image, index, existing_caption=existing_caption)
|
||||||
|
|
||||||
if process_flip:
|
if process_flip:
|
||||||
save_pic_with_caption(ImageOps.mirror(image), index)
|
save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
|
||||||
|
|
||||||
for index, imagefile in enumerate(tqdm.tqdm(files)):
|
for index, imagefile in enumerate(tqdm.tqdm(files)):
|
||||||
subindex = [0]
|
subindex = [0]
|
||||||
|
@ -86,6 +95,13 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
existing_caption = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
existing_caption = open(os.path.splitext(filename)[0] + '.txt', 'r').read()
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
if shared.state.interrupted:
|
if shared.state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -97,20 +113,20 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
|
||||||
img = img.resize((width, height * img.height // img.width))
|
img = img.resize((width, height * img.height // img.width))
|
||||||
|
|
||||||
top = img.crop((0, 0, width, height))
|
top = img.crop((0, 0, width, height))
|
||||||
save_pic(top, index)
|
save_pic(top, index, existing_caption=existing_caption)
|
||||||
|
|
||||||
bot = img.crop((0, img.height - height, width, img.height))
|
bot = img.crop((0, img.height - height, width, img.height))
|
||||||
save_pic(bot, index)
|
save_pic(bot, index, existing_caption=existing_caption)
|
||||||
elif process_split and is_wide:
|
elif process_split and is_wide:
|
||||||
img = img.resize((width * img.width // img.height, height))
|
img = img.resize((width * img.width // img.height, height))
|
||||||
|
|
||||||
left = img.crop((0, 0, width, height))
|
left = img.crop((0, 0, width, height))
|
||||||
save_pic(left, index)
|
save_pic(left, index, existing_caption=existing_caption)
|
||||||
|
|
||||||
right = img.crop((img.width - width, 0, img.width, height))
|
right = img.crop((img.width - width, 0, img.width, height))
|
||||||
save_pic(right, index)
|
save_pic(right, index, existing_caption=existing_caption)
|
||||||
else:
|
else:
|
||||||
img = images.resize_image(1, img, width, height)
|
img = images.resize_image(1, img, width, height)
|
||||||
save_pic(img, index)
|
save_pic(img, index, existing_caption=existing_caption)
|
||||||
|
|
||||||
shared.state.nextjob()
|
shared.state.nextjob()
|
||||||
|
|
|
@ -153,7 +153,7 @@ class EmbeddingDatabase:
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
def create_embedding(name, num_vectors_per_token, init_text='*'):
|
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
|
||||||
cond_model = shared.sd_model.cond_stage_model
|
cond_model = shared.sd_model.cond_stage_model
|
||||||
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
|
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
|
||||||
|
|
||||||
|
@ -165,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
|
||||||
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
||||||
|
|
||||||
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
|
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
|
||||||
assert not os.path.exists(fn), f"file {fn} already exists"
|
if not overwrite_old:
|
||||||
|
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||||
|
|
||||||
embedding = Embedding(vec, name)
|
embedding = Embedding(vec, name)
|
||||||
embedding.step = 0
|
embedding.step = 0
|
||||||
|
|
|
@ -7,8 +7,8 @@ import modules.textual_inversion.preprocess
|
||||||
from modules import sd_hijack, shared
|
from modules import sd_hijack, shared
|
||||||
|
|
||||||
|
|
||||||
def create_embedding(name, initialization_text, nvpt):
|
def create_embedding(name, initialization_text, nvpt, overwrite_old):
|
||||||
filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
|
filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
|
||||||
|
|
||||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||||
|
|
||||||
|
|
|
@ -1211,6 +1211,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
new_embedding_name = gr.Textbox(label="Name")
|
new_embedding_name = gr.Textbox(label="Name")
|
||||||
initialization_text = gr.Textbox(label="Initialization text", value="*")
|
initialization_text = gr.Textbox(label="Initialization text", value="*")
|
||||||
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
|
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
|
||||||
|
overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=3):
|
with gr.Column(scale=3):
|
||||||
|
@ -1224,6 +1225,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
|
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
|
||||||
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
|
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
|
||||||
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
|
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
|
||||||
|
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
|
||||||
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
|
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
@ -1238,6 +1240,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
process_dst = gr.Textbox(label='Destination directory')
|
process_dst = gr.Textbox(label='Destination directory')
|
||||||
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
||||||
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
||||||
|
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
process_flip = gr.Checkbox(label='Create flipped copies')
|
process_flip = gr.Checkbox(label='Create flipped copies')
|
||||||
|
@ -1253,14 +1256,17 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
run_preprocess = gr.Button(value="Preprocess", variant='primary')
|
run_preprocess = gr.Button(value="Preprocess", variant='primary')
|
||||||
|
|
||||||
with gr.Tab(label="Train"):
|
with gr.Tab(label="Train"):
|
||||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||||
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
|
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
|
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
|
||||||
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
|
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
|
||||||
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
|
with gr.Row():
|
||||||
|
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
|
||||||
|
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
|
||||||
|
|
||||||
batch_size = gr.Number(label='Batch size', value=1, precision=0)
|
batch_size = gr.Number(label='Batch size', value=1, precision=0)
|
||||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
||||||
|
@ -1294,6 +1300,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
new_embedding_name,
|
new_embedding_name,
|
||||||
initialization_text,
|
initialization_text,
|
||||||
nvpt,
|
nvpt,
|
||||||
|
overwrite_old_embedding,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
train_embedding_name,
|
train_embedding_name,
|
||||||
|
@ -1307,6 +1314,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
inputs=[
|
inputs=[
|
||||||
new_hypernetwork_name,
|
new_hypernetwork_name,
|
||||||
new_hypernetwork_sizes,
|
new_hypernetwork_sizes,
|
||||||
|
overwrite_old_hypernetwork,
|
||||||
new_hypernetwork_layer_structure,
|
new_hypernetwork_layer_structure,
|
||||||
new_hypernetwork_add_layer_norm,
|
new_hypernetwork_add_layer_norm,
|
||||||
new_hypernetwork_activation_func,
|
new_hypernetwork_activation_func,
|
||||||
|
@ -1326,6 +1334,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
process_dst,
|
process_dst,
|
||||||
process_width,
|
process_width,
|
||||||
process_height,
|
process_height,
|
||||||
|
preprocess_txt_action,
|
||||||
process_flip,
|
process_flip,
|
||||||
process_split,
|
process_split,
|
||||||
process_caption,
|
process_caption,
|
||||||
|
@ -1342,7 +1351,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
_js="start_training_textual_inversion",
|
_js="start_training_textual_inversion",
|
||||||
inputs=[
|
inputs=[
|
||||||
train_embedding_name,
|
train_embedding_name,
|
||||||
learn_rate,
|
embedding_learn_rate,
|
||||||
batch_size,
|
batch_size,
|
||||||
dataset_directory,
|
dataset_directory,
|
||||||
log_directory,
|
log_directory,
|
||||||
|
@ -1367,7 +1376,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
_js="start_training_textual_inversion",
|
_js="start_training_textual_inversion",
|
||||||
inputs=[
|
inputs=[
|
||||||
train_hypernetwork_name,
|
train_hypernetwork_name,
|
||||||
learn_rate,
|
hypernetwork_learn_rate,
|
||||||
batch_size,
|
batch_size,
|
||||||
dataset_directory,
|
dataset_directory,
|
||||||
log_directory,
|
log_directory,
|
||||||
|
|
Loading…
Reference in a new issue