Gradient accumulation, autocast fix, new latent sampling method, etc

This commit is contained in:
flamelaw 2022-11-20 12:35:26 +09:00
parent 47a44c7e42
commit bd68e35de3
7 changed files with 408 additions and 273 deletions

View file

@ -367,13 +367,13 @@ def report_statistics(loss_info:dict):
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem. # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images from modules import images
save_hypernetwork_every = save_hypernetwork_every or 0 save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0 create_image_every = create_image_every or 0
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
path = shared.hypernetworks.get(hypernetwork_name, None) path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork() shared.loaded_hypernetwork = Hypernetwork()
@ -403,29 +403,25 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
hypernetwork = shared.loaded_hypernetwork hypernetwork = shared.loaded_hypernetwork
checkpoint = sd_models.select_checkpoint() checkpoint = sd_models.select_checkpoint()
ititial_step = hypernetwork.step or 0 initial_step = hypernetwork.step or 0
if ititial_step >= steps: if initial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps" shared.state.textinfo = f"Model has already been trained beyond specified max steps"
return hypernetwork, filename return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
# dataset loading may take a while, so input validations and early returns should be done before this # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) pin_memory = shared.opts.pin_memory
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, batch_size=ds.batch_size, pin_memory=pin_memory)
if unload: if unload:
shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu)
size = len(ds.indexes)
loss_dict = defaultdict(lambda : deque(maxlen = 1024))
losses = torch.zeros((size,))
previous_mean_losses = [0]
previous_mean_loss = 0
print("Mean loss of {} elements".format(size))
weights = hypernetwork.weights() weights = hypernetwork.weights()
for weight in weights: for weight in weights:
weight.requires_grad = True weight.requires_grad = True
@ -446,62 +442,81 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
print("Cannot resume from saved optimizer!") print("Cannot resume from saved optimizer!")
print(e) print(e)
scaler = torch.cuda.amp.GradScaler()
batch_size = ds.batch_size
gradient_step = ds.gradient_step
# n steps = batch_size * gradient_step * n image processed
steps_per_epoch = len(ds) // batch_size // gradient_step
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
loss_step = 0
_loss_step = 0 #internal
# size = len(ds.indexes)
# loss_dict = defaultdict(lambda : deque(maxlen = 1024))
# losses = torch.zeros((size,))
# previous_mean_losses = [0]
# previous_mean_loss = 0
# print("Mean loss of {} elements".format(size))
steps_without_grad = 0 steps_without_grad = 0
last_saved_file = "<none>" last_saved_file = "<none>"
last_saved_image = "<none>" last_saved_image = "<none>"
forced_filename = "<none>" forced_filename = "<none>"
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) pbar = tqdm.tqdm(total=steps - initial_step)
for i, entries in pbar: try:
hypernetwork.step = i + ititial_step for i in range((steps-initial_step) * gradient_step):
if len(loss_dict) > 0: if scheduler.finished:
previous_mean_losses = [i[-1] for i in loss_dict.values()] break
previous_mean_loss = mean(previous_mean_losses) if shared.state.interrupted:
break
for j, batch in enumerate(dl):
# works as a drop_last=True for gradient accumulation
if j == max_steps_per_epoch:
break
scheduler.apply(optimizer, hypernetwork.step) scheduler.apply(optimizer, hypernetwork.step)
if scheduler.finished: if scheduler.finished:
break break
if shared.state.interrupted: if shared.state.interrupted:
break break
with torch.autocast("cuda"): with torch.autocast("cuda"):
c = stack_conds([entry.cond for entry in entries]).to(devices.device) x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device) if tag_drop_out != 0 or shuffle_tags:
x = torch.stack([entry.latent for entry in entries]).to(devices.device) shared.sd_model.cond_stage_model.to(devices.device)
loss = shared.sd_model(x, c)[0] c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
shared.sd_model.cond_stage_model.to(devices.cpu)
else:
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
loss = shared.sd_model(x, c)[0] / gradient_step
del x del x
del c del c
losses[hypernetwork.step % losses.shape[0]] = loss.item() _loss_step += loss.item()
for entry in entries: scaler.scale(loss).backward()
loss_dict[entry.filename].append(loss.item()) # go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
optimizer.zero_grad() continue
weights[0].grad = None # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
loss.backward() # scaler.unscale_(optimizer)
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
if weights[0].grad is None: # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
steps_without_grad += 1 # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
else: scaler.step(optimizer)
steps_without_grad = 0 scaler.update()
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue' hypernetwork.step += 1
pbar.update()
optimizer.step() optimizer.zero_grad(set_to_none=True)
loss_step = _loss_step
_loss_step = 0
steps_done = hypernetwork.step + 1 steps_done = hypernetwork.step + 1
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): epoch_num = hypernetwork.step // steps_per_epoch
raise RuntimeError("Loss diverged.") epoch_step = hypernetwork.step % steps_per_epoch
if len(previous_mean_losses) > 1:
std = stdev(previous_mean_losses)
else:
std = 0
dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
pbar.set_description(dataset_loss_info)
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint. # Before saving, change name to match current checkpoint.
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
@ -512,8 +527,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
"loss": f"{previous_mean_loss:.7f}", "loss": f"{loss_step:.7f}",
"learn_rate": scheduler.learn_rate "learn_rate": scheduler.learn_rate
}) })
@ -521,7 +536,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
forced_filename = f'{hypernetwork_name}-{steps_done}' forced_filename = f'{hypernetwork_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename) last_saved_image = os.path.join(images_dir, forced_filename)
optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device)
@ -541,8 +555,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
p.width = preview_width p.width = preview_width
p.height = preview_height p.height = preview_height
else: else:
p.prompt = entries[0].cond_text p.prompt = batch.cond_text[0]
p.steps = 20 p.steps = 20
p.width = training_width
p.height = training_height
preview_text = p.prompt preview_text = p.prompt
@ -562,15 +578,19 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.state.textinfo = f""" shared.state.textinfo = f"""
<p> <p>
Loss: {previous_mean_loss:.7f}<br/> Loss: {loss_step:.7f}<br/>
Step: {hypernetwork.step}<br/> Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/> Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved hypernetwork: {html.escape(last_saved_file)}<br/> Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/> Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>
""" """
except Exception:
report_statistics(loss_dict) print(traceback.format_exc(), file=sys.stderr)
finally:
pbar.leave = False
pbar.close()
#report_statistics(loss_dict)
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.optimizer_name = optimizer_name hypernetwork.optimizer_name = optimizer_name
@ -579,6 +599,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename) save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
del optimizer del optimizer
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
return hypernetwork, filename return hypernetwork, filename
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename): def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):

View file

@ -8,7 +8,7 @@ from torch import einsum
from torch.nn.functional import silu from torch.nn.functional import silu
import modules.textual_inversion.textual_inversion import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
from modules.shared import opts, device, cmd_opts from modules.shared import opts, device, cmd_opts
from modules.sd_hijack_optimizations import invokeAI_mps_available from modules.sd_hijack_optimizations import invokeAI_mps_available
@ -59,6 +59,10 @@ def undo_optimizations():
def get_target_prompt_token_count(token_count): def get_target_prompt_token_count(token_count):
return math.ceil(max(token_count, 1) / 75) * 75 return math.ceil(max(token_count, 1) / 75) * 75
def fix_checkpoint():
ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
class StableDiffusionModelHijack: class StableDiffusionModelHijack:
fixes = None fixes = None
@ -78,6 +82,7 @@ class StableDiffusionModelHijack:
self.clip = m.cond_stage_model self.clip = m.cond_stage_model
apply_optimizations() apply_optimizations()
fix_checkpoint()
def flatten(el): def flatten(el):
flattened = [flatten(children) for children in el.children()] flattened = [flatten(children) for children in el.children()]

View file

@ -0,0 +1,10 @@
from torch.utils.checkpoint import checkpoint
def BasicTransformerBlock_forward(self, x, context=None):
return checkpoint(self._forward, x, context)
def AttentionBlock_forward(self, x):
return checkpoint(self._forward, x)
def ResBlock_forward(self, x, emb):
return checkpoint(self._forward, x, emb)

View file

@ -322,8 +322,7 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"shuffle_tags": OptionInfo(False, "Shuffleing tags by ',' when create texts."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
"tag_drop_out": OptionInfo(0, "Dropout tags when create texts", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.1}),
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),

View file

@ -3,7 +3,7 @@ import numpy as np
import PIL import PIL
import torch import torch
from PIL import Image from PIL import Image
from torch.utils.data import Dataset from torch.utils.data import Dataset, DataLoader
from torchvision import transforms from torchvision import transforms
import random import random
@ -11,25 +11,28 @@ import tqdm
from modules import devices, shared from modules import devices, shared
import re import re
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
re_numbers_at_start = re.compile(r"^[-\d]+\s*") re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class DatasetEntry: class DatasetEntry:
def __init__(self, filename=None, latent=None, filename_text=None): def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
self.filename = filename self.filename = filename
self.latent = latent
self.filename_text = filename_text self.filename_text = filename_text
self.cond = None self.latent_dist = latent_dist
self.cond_text = None self.latent_sample = latent_sample
self.cond = cond
self.cond_text = cond_text
self.pixel_values = pixel_values
class PersonalizedBase(Dataset): class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1): def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token self.placeholder_token = placeholder_token
self.batch_size = batch_size
self.width = width self.width = width
self.height = height self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@ -45,11 +48,16 @@ class PersonalizedBase(Dataset):
assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty" assert os.listdir(data_root), "Dataset directory is empty"
cond_model = shared.sd_model.cond_stage_model
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
self.shuffle_tags = shuffle_tags
self.tag_drop_out = tag_drop_out
print("Preparing dataset...") print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths): for path in tqdm.tqdm(self.image_paths):
if shared.state.interrupted:
raise Exception("inturrupted")
try: try:
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
except Exception: except Exception:
@ -71,37 +79,58 @@ class PersonalizedBase(Dataset):
npimage = np.array(image).astype(np.uint8) npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32) npimage = (npimage / 127.5 - 1.0).astype(np.float32)
torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32) torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
torchdata = torch.moveaxis(torchdata, 2, 0) latent_sample = None
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() with torch.autocast("cuda"):
init_latent = init_latent.to(devices.cpu) latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent) if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
latent_sampling_method = "once"
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
elif latent_sampling_method == "deterministic":
# Works only for DiagonalGaussianDistribution
latent_dist.std = 0
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
elif latent_sampling_method == "random":
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
if include_cond: if not (self.tag_drop_out != 0 or self.shuffle_tags):
entry.cond_text = self.create_text(filename_text) entry.cond_text = self.create_text(filename_text)
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
with torch.autocast("cuda"):
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
# elif not include_cond:
# _, _, _, _, hijack_fixes, token_count = cond_model.process_text([entry.cond_text])
# max_n = token_count // 75
# index_list = [ [] for _ in range(max_n + 1) ]
# for n, (z, _) in hijack_fixes[0]:
# index_list[n].append(z)
# with torch.autocast("cuda"):
# entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
# entry.emb_index = index_list
self.dataset.append(entry) self.dataset.append(entry)
del torchdata
del latent_dist
del latent_sample
assert len(self.dataset) > 0, "No images have been found in the dataset." self.length = len(self.dataset)
self.length = len(self.dataset) * repeats // batch_size assert self.length > 0, "No images have been found in the dataset."
self.batch_size = min(batch_size, self.length)
self.dataset_length = len(self.dataset) self.gradient_step = min(gradient_step, self.length // self.batch_size)
self.indexes = None self.latent_sampling_method = latent_sampling_method
self.shuffle()
def shuffle(self):
self.indexes = np.random.permutation(self.dataset_length)
def create_text(self, filename_text): def create_text(self, filename_text):
text = random.choice(self.lines) text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token) text = text.replace("[name]", self.placeholder_token)
tags = filename_text.split(',') tags = filename_text.split(',')
if shared.opts.tag_drop_out != 0: if self.tag_drop_out != 0:
tags = [t for t in tags if random.random() > shared.opts.tag_drop_out] tags = [t for t in tags if random.random() > self.tag_drop_out]
if shared.opts.shuffle_tags: if self.shuffle_tags:
random.shuffle(tags) random.shuffle(tags)
text = text.replace("[filewords]", ','.join(tags)) text = text.replace("[filewords]", ','.join(tags))
return text return text
@ -110,19 +139,28 @@ class PersonalizedBase(Dataset):
return self.length return self.length
def __getitem__(self, i): def __getitem__(self, i):
res = [] entry = self.dataset[i]
if self.tag_drop_out != 0 or self.shuffle_tags:
for j in range(self.batch_size):
position = i * self.batch_size + j
if position % len(self.indexes) == 0:
self.shuffle()
index = self.indexes[position % len(self.indexes)]
entry = self.dataset[index]
if entry.cond is None:
entry.cond_text = self.create_text(entry.filename_text) entry.cond_text = self.create_text(entry.filename_text)
if self.latent_sampling_method == "random":
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist)
return entry
res.append(entry) class PersonalizedDataLoader(DataLoader):
def __init__(self, *args, **kwargs):
super(PersonalizedDataLoader, self).__init__(shuffle=True, drop_last=True, *args, **kwargs)
self.collate_fn = collate_wrapper
return res
class BatchLoader:
def __init__(self, data):
self.cond_text = [entry.cond_text for entry in data]
self.cond = [entry.cond for entry in data]
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
def pin_memory(self):
self.latent_sample = self.latent_sample.pin_memory()
return self
def collate_wrapper(batch):
return BatchLoader(batch)

View file

@ -184,7 +184,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
if shared.opts.training_write_csv_every == 0: if shared.opts.training_write_csv_every == 0:
return return
if (step + 1) % shared.opts.training_write_csv_every != 0: if step % shared.opts.training_write_csv_every != 0:
return return
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
@ -194,21 +194,23 @@ def write_loss(log_directory, filename, step, epoch_len, values):
if write_csv_header: if write_csv_header:
csv_writer.writeheader() csv_writer.writeheader()
epoch = step // epoch_len epoch = (step - 1) // epoch_len
epoch_step = step % epoch_len epoch_step = (step - 1) % epoch_len
csv_writer.writerow({ csv_writer.writerow({
"step": step + 1, "step": step,
"epoch": epoch, "epoch": epoch,
"epoch_step": epoch_step + 1, "epoch_step": epoch_step,
**values, **values,
}) })
def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
assert model_name, f"{name} not selected" assert model_name, f"{name} not selected"
assert learn_rate, "Learning rate is empty or 0" assert learn_rate, "Learning rate is empty or 0"
assert isinstance(batch_size, int), "Batch size must be integer" assert isinstance(batch_size, int), "Batch size must be integer"
assert batch_size > 0, "Batch size must be positive" assert batch_size > 0, "Batch size must be positive"
assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
assert gradient_step > 0, "Gradient accumulation step must be positive"
assert data_root, "Dataset directory is empty" assert data_root, "Dataset directory is empty"
assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty" assert os.listdir(data_root), "Dataset directory is empty"
@ -224,10 +226,10 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat
if save_model_every or create_image_every: if save_model_every or create_image_every:
assert log_directory, "Log directory is empty" assert log_directory, "Log directory is empty"
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0 save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0 create_image_every = create_image_every or 0
validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
shared.state.textinfo = "Initializing textual inversion training..." shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps shared.state.job_count = steps
@ -255,76 +257,112 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
else: else:
images_embeds_dir = None images_embeds_dir = None
cond_model = shared.sd_model.cond_stage_model
hijack = sd_hijack.model_hijack hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name] embedding = hijack.embedding_db.word_embeddings[embedding_name]
checkpoint = sd_models.select_checkpoint() checkpoint = sd_models.select_checkpoint()
ititial_step = embedding.step or 0 initial_step = embedding.step or 0
if ititial_step >= steps: if initial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps" shared.state.textinfo = f"Model has already been trained beyond specified max steps"
return embedding, filename return embedding, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
# dataset loading may take a while, so input validations and early returns should be done before this # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) pin_memory = shared.opts.pin_memory
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
latent_sampling_method = ds.latent_sampling_method
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, batch_size=ds.batch_size, pin_memory=False)
if unload: if unload:
shared.sd_model.first_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu)
embedding.vec.requires_grad = True embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
scaler = torch.cuda.amp.GradScaler()
batch_size = ds.batch_size
gradient_step = ds.gradient_step
# n steps = batch_size * gradient_step * n image processed
steps_per_epoch = len(ds) // batch_size // gradient_step
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
loss_step = 0
_loss_step = 0 #internal
losses = torch.zeros((32,))
last_saved_file = "<none>" last_saved_file = "<none>"
last_saved_image = "<none>" last_saved_image = "<none>"
forced_filename = "<none>" forced_filename = "<none>"
embedding_yet_to_be_embedded = False embedding_yet_to_be_embedded = False
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) pbar = tqdm.tqdm(total=steps - initial_step)
for i, entries in pbar: try:
embedding.step = i + ititial_step for i in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
if shared.state.interrupted:
break
for j, batch in enumerate(dl):
# works as a drop_last=True for gradient accumulation
if j == max_steps_per_epoch:
break
scheduler.apply(optimizer, embedding.step) scheduler.apply(optimizer, embedding.step)
if scheduler.finished: if scheduler.finished:
break break
if shared.state.interrupted: if shared.state.interrupted:
break break
with torch.autocast("cuda"): with torch.autocast("cuda"):
c = cond_model([entry.cond_text for entry in entries]) # c = stack_conds(batch.cond).to(devices.device)
x = torch.stack([entry.latent for entry in entries]).to(devices.device) # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
loss = shared.sd_model(x, c)[0] # print(mask)
# c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
c = shared.sd_model.cond_stage_model(batch.cond_text)
loss = shared.sd_model(x, c)[0] / gradient_step
del x del x
losses[embedding.step % losses.shape[0]] = loss.item() _loss_step += loss.item()
scaler.scale(loss).backward()
optimizer.zero_grad() # go back until we reach gradient accumulation steps
loss.backward() if (j + 1) % gradient_step != 0:
optimizer.step() continue
#print(f"grad:{embedding.vec.grad.detach().cpu().abs().mean().item():.7f}")
#scaler.unscale_(optimizer)
#print(f"grad:{embedding.vec.grad.detach().cpu().abs().mean().item():.7f}")
#torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=1.0)
#print(f"grad:{embedding.vec.grad.detach().cpu().abs().mean().item():.7f}")
scaler.step(optimizer)
scaler.update()
embedding.step += 1
pbar.update()
optimizer.zero_grad(set_to_none=True)
loss_step = _loss_step
_loss_step = 0
steps_done = embedding.step + 1 steps_done = embedding.step + 1
epoch_num = embedding.step // len(ds) epoch_num = embedding.step // steps_per_epoch
epoch_step = embedding.step % len(ds) epoch_step = embedding.step % steps_per_epoch
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
if embedding_dir is not None and steps_done % save_embedding_every == 0: if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint. # Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}' embedding_name_every = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt') last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
#if shared.opts.save_optimizer_state:
#embedding.optimizer_state_dict = optimizer.state_dict()
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True) save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
embedding_yet_to_be_embedded = True embedding_yet_to_be_embedded = True
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
"loss": f"{losses.mean():.7f}", "loss": f"{loss_step:.7f}",
"learn_rate": scheduler.learn_rate "learn_rate": scheduler.learn_rate
}) })
@ -351,7 +389,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
p.width = preview_width p.width = preview_width
p.height = preview_height p.height = preview_height
else: else:
p.prompt = entries[0].cond_text p.prompt = batch.cond_text[0]
p.steps = 20 p.steps = 20
p.width = training_width p.width = training_width
p.height = training_height p.height = training_height
@ -359,12 +397,15 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
preview_text = p.prompt preview_text = p.prompt
processed = processing.process_images(p) processed = processing.process_images(p)
image = processed.images[0] image = processed.images[0] if len(processed.images) > 0 else None
if unload: if unload:
shared.sd_model.first_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu)
if image is not None:
shared.state.current_image = image shared.state.current_image = image
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
@ -399,16 +440,21 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
shared.state.textinfo = f""" shared.state.textinfo = f"""
<p> <p>
Loss: {losses.mean():.7f}<br/> Loss: {loss_step:.7f}<br/>
Step: {embedding.step}<br/> Step: {embedding.step}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/> Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/> Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/> Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>
""" """
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
except Exception:
print(traceback.format_exc(), file=sys.stderr)
pass
finally:
pbar.leave = False
pbar.close()
shared.sd_model.first_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device)
return embedding, filename return embedding, filename

View file

@ -1289,6 +1289,7 @@ def create_ui(wrap_gradio_gpu_call):
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
batch_size = gr.Number(label='Batch size', value=1, precision=0) batch_size = gr.Number(label='Batch size', value=1, precision=0)
gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
@ -1299,6 +1300,11 @@ def create_ui(wrap_gradio_gpu_call):
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
with gr.Row():
shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False)
tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0)
with gr.Row():
latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'])
with gr.Row(): with gr.Row():
interrupt_training = gr.Button(value="Interrupt") interrupt_training = gr.Button(value="Interrupt")
@ -1387,11 +1393,15 @@ def create_ui(wrap_gradio_gpu_call):
train_embedding_name, train_embedding_name,
embedding_learn_rate, embedding_learn_rate,
batch_size, batch_size,
gradient_step,
dataset_directory, dataset_directory,
log_directory, log_directory,
training_width, training_width,
training_height, training_height,
steps, steps,
shuffle_tags,
tag_drop_out,
latent_sampling_method,
create_image_every, create_image_every,
save_embedding_every, save_embedding_every,
template_file, template_file,
@ -1412,11 +1422,15 @@ def create_ui(wrap_gradio_gpu_call):
train_hypernetwork_name, train_hypernetwork_name,
hypernetwork_learn_rate, hypernetwork_learn_rate,
batch_size, batch_size,
gradient_step,
dataset_directory, dataset_directory,
log_directory, log_directory,
training_width, training_width,
training_height, training_height,
steps, steps,
shuffle_tags,
tag_drop_out,
latent_sampling_method,
create_image_every, create_image_every,
save_embedding_every, save_embedding_every,
template_file, template_file,