Textual Inversion: Added custom training image size and number of repeats per input image in a single epoch

This commit is contained in:
alg-wiki 2022-10-10 17:07:46 +09:00
parent 8acc901ba3
commit 3110f895b2
No known key found for this signature in database
GPG key ID: 9D27E57D1F43E750
4 changed files with 24 additions and 9 deletions

View file

@ -15,13 +15,13 @@ re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
class PersonalizedBase(Dataset): class PersonalizedBase(Dataset):
def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None): def __init__(self, data_root, size, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
self.placeholder_token = placeholder_token self.placeholder_token = placeholder_token
self.size = size self.size = size
self.width = width self.width = size
self.height = height self.height = size
self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.dataset = [] self.dataset = []

View file

@ -7,8 +7,8 @@ import tqdm
from modules import shared, images from modules import shared, images
def preprocess(process_src, process_dst, process_flip, process_split, process_caption): def preprocess(process_src, process_dst, process_size, process_flip, process_split, process_caption):
size = 512 size = process_size
src = os.path.abspath(process_src) src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst) dst = os.path.abspath(process_dst)

View file

@ -6,6 +6,7 @@ import torch
import tqdm import tqdm
import html import html
import datetime import datetime
import math
from modules import shared, devices, sd_hijack, processing, sd_models from modules import shared, devices, sd_hijack, processing, sd_models
@ -156,7 +157,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn return fn
def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file): def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_size, steps, num_repeats, create_image_every, save_embedding_every, template_file):
assert embedding_name, 'embedding not selected' assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..." shared.state.textinfo = "Initializing textual inversion training..."
@ -182,7 +183,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"): with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=training_size, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack hijack = sd_hijack.model_hijack
@ -200,6 +201,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
if ititial_step > steps: if ititial_step > steps:
return embedding, filename return embedding, filename
tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
epoch_len = (tr_img_len * num_repeats) + tr_img_len
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, (x, text) in pbar: for i, (x, text) in pbar:
embedding.step = i + ititial_step embedding.step = i + ititial_step
@ -223,7 +227,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
loss.backward() loss.backward()
optimizer.step() optimizer.step()
pbar.set_description(f"loss: {losses.mean():.7f}") epoch_num = math.floor(embedding.step / epoch_len)
epoch_step = embedding.step - (epoch_num * epoch_len)
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0: if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt') last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
@ -236,6 +243,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
sd_model=shared.sd_model, sd_model=shared.sd_model,
prompt=text, prompt=text,
steps=20, steps=20,
height=training_size,
width=training_size,
do_not_save_grid=True, do_not_save_grid=True,
do_not_save_samples=True, do_not_save_samples=True,
) )

View file

@ -1029,6 +1029,7 @@ def create_ui(wrap_gradio_gpu_call):
process_src = gr.Textbox(label='Source directory') process_src = gr.Textbox(label='Source directory')
process_dst = gr.Textbox(label='Destination directory') process_dst = gr.Textbox(label='Destination directory')
process_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512)
with gr.Row(): with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies') process_flip = gr.Checkbox(label='Create flipped copies')
@ -1043,13 +1044,15 @@ def create_ui(wrap_gradio_gpu_call):
run_preprocess = gr.Button(value="Preprocess", variant='primary') run_preprocess = gr.Button(value="Preprocess", variant='primary')
with gr.Group(): with gr.Group():
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>") gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
learn_rate = gr.Number(label='Learning rate', value=5.0e-03) learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
training_size = gr.Slider(minimum=64, maximum=2048, step=64, label="Size (width and height)", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0) steps = gr.Number(label='Max steps', value=100000, precision=0)
num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
@ -1092,6 +1095,7 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[ inputs=[
process_src, process_src,
process_dst, process_dst,
process_size,
process_flip, process_flip,
process_split, process_split,
process_caption, process_caption,
@ -1110,7 +1114,9 @@ def create_ui(wrap_gradio_gpu_call):
learn_rate, learn_rate,
dataset_directory, dataset_directory,
log_directory, log_directory,
training_size,
steps, steps,
num_repeats,
create_image_every, create_image_every,
save_embedding_every, save_embedding_every,
template_file, template_file,