Merge remote-tracking branch 'flamelaw/master'
This commit is contained in:
commit
b48b7999c8
7 changed files with 431 additions and 280 deletions
|
@ -38,7 +38,7 @@ class HypernetworkModule(torch.nn.Module):
|
||||||
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
|
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
|
||||||
|
|
||||||
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
|
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
|
||||||
add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=True):
|
add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
assert layer_structure is not None, "layer_structure must not be None"
|
assert layer_structure is not None, "layer_structure must not be None"
|
||||||
|
@ -154,16 +154,28 @@ class Hypernetwork:
|
||||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
||||||
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
|
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
|
||||||
)
|
)
|
||||||
|
self.eval_mode()
|
||||||
|
|
||||||
def weights(self):
|
def weights(self):
|
||||||
res = []
|
res = []
|
||||||
|
for k, layers in self.layers.items():
|
||||||
|
for layer in layers:
|
||||||
|
res += layer.parameters()
|
||||||
|
return res
|
||||||
|
|
||||||
|
def train_mode(self):
|
||||||
for k, layers in self.layers.items():
|
for k, layers in self.layers.items():
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
layer.train()
|
layer.train()
|
||||||
res += layer.trainables()
|
for param in layer.parameters():
|
||||||
|
param.requires_grad = True
|
||||||
|
|
||||||
return res
|
def eval_mode(self):
|
||||||
|
for k, layers in self.layers.items():
|
||||||
|
for layer in layers:
|
||||||
|
layer.eval()
|
||||||
|
for param in layer.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
def save(self, filename):
|
def save(self, filename):
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
|
@ -367,13 +379,13 @@ def report_statistics(loss_info:dict):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||||
from modules import images
|
from modules import images
|
||||||
|
|
||||||
save_hypernetwork_every = save_hypernetwork_every or 0
|
save_hypernetwork_every = save_hypernetwork_every or 0
|
||||||
create_image_every = create_image_every or 0
|
create_image_every = create_image_every or 0
|
||||||
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
|
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
|
||||||
|
|
||||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||||
shared.loaded_hypernetwork = Hypernetwork()
|
shared.loaded_hypernetwork = Hypernetwork()
|
||||||
|
@ -403,32 +415,30 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
hypernetwork = shared.loaded_hypernetwork
|
hypernetwork = shared.loaded_hypernetwork
|
||||||
checkpoint = sd_models.select_checkpoint()
|
checkpoint = sd_models.select_checkpoint()
|
||||||
|
|
||||||
ititial_step = hypernetwork.step or 0
|
initial_step = hypernetwork.step or 0
|
||||||
if ititial_step >= steps:
|
if initial_step >= steps:
|
||||||
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
|
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
|
||||||
return hypernetwork, filename
|
return hypernetwork, filename
|
||||||
|
|
||||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
||||||
|
|
||||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
with torch.autocast("cuda"):
|
|
||||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
|
pin_memory = shared.opts.pin_memory
|
||||||
|
|
||||||
|
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
|
||||||
|
|
||||||
|
latent_sampling_method = ds.latent_sampling_method
|
||||||
|
|
||||||
|
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
|
||||||
|
|
||||||
if unload:
|
if unload:
|
||||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
|
||||||
size = len(ds.indexes)
|
|
||||||
loss_dict = defaultdict(lambda : deque(maxlen = 1024))
|
|
||||||
losses = torch.zeros((size,))
|
|
||||||
previous_mean_losses = [0]
|
|
||||||
previous_mean_loss = 0
|
|
||||||
print("Mean loss of {} elements".format(size))
|
|
||||||
|
|
||||||
weights = hypernetwork.weights()
|
weights = hypernetwork.weights()
|
||||||
for weight in weights:
|
hypernetwork.train_mode()
|
||||||
weight.requires_grad = True
|
|
||||||
|
|
||||||
# Here we use optimizer from saved HN, or we can specify as UI option.
|
# Here we use optimizer from saved HN, or we can specify as UI option.
|
||||||
if hypernetwork.optimizer_name in optimizer_dict:
|
if hypernetwork.optimizer_name in optimizer_dict:
|
||||||
|
@ -446,131 +456,156 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
print("Cannot resume from saved optimizer!")
|
print("Cannot resume from saved optimizer!")
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
|
scaler = torch.cuda.amp.GradScaler()
|
||||||
|
|
||||||
|
batch_size = ds.batch_size
|
||||||
|
gradient_step = ds.gradient_step
|
||||||
|
# n steps = batch_size * gradient_step * n image processed
|
||||||
|
steps_per_epoch = len(ds) // batch_size // gradient_step
|
||||||
|
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
|
||||||
|
loss_step = 0
|
||||||
|
_loss_step = 0 #internal
|
||||||
|
# size = len(ds.indexes)
|
||||||
|
# loss_dict = defaultdict(lambda : deque(maxlen = 1024))
|
||||||
|
# losses = torch.zeros((size,))
|
||||||
|
# previous_mean_losses = [0]
|
||||||
|
# previous_mean_loss = 0
|
||||||
|
# print("Mean loss of {} elements".format(size))
|
||||||
|
|
||||||
steps_without_grad = 0
|
steps_without_grad = 0
|
||||||
|
|
||||||
last_saved_file = "<none>"
|
last_saved_file = "<none>"
|
||||||
last_saved_image = "<none>"
|
last_saved_image = "<none>"
|
||||||
forced_filename = "<none>"
|
forced_filename = "<none>"
|
||||||
|
|
||||||
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
pbar = tqdm.tqdm(total=steps - initial_step)
|
||||||
for i, entries in pbar:
|
try:
|
||||||
hypernetwork.step = i + ititial_step
|
for i in range((steps-initial_step) * gradient_step):
|
||||||
if len(loss_dict) > 0:
|
if scheduler.finished:
|
||||||
previous_mean_losses = [i[-1] for i in loss_dict.values()]
|
break
|
||||||
previous_mean_loss = mean(previous_mean_losses)
|
if shared.state.interrupted:
|
||||||
|
break
|
||||||
scheduler.apply(optimizer, hypernetwork.step)
|
for j, batch in enumerate(dl):
|
||||||
if scheduler.finished:
|
# works as a drop_last=True for gradient accumulation
|
||||||
break
|
if j == max_steps_per_epoch:
|
||||||
|
break
|
||||||
|
scheduler.apply(optimizer, hypernetwork.step)
|
||||||
|
if scheduler.finished:
|
||||||
|
break
|
||||||
|
if shared.state.interrupted:
|
||||||
|
break
|
||||||
|
|
||||||
if shared.state.interrupted:
|
with torch.autocast("cuda"):
|
||||||
break
|
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||||
|
if tag_drop_out != 0 or shuffle_tags:
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
|
c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
|
else:
|
||||||
|
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
|
||||||
|
loss = shared.sd_model(x, c)[0] / gradient_step
|
||||||
|
del x
|
||||||
|
del c
|
||||||
|
|
||||||
with torch.autocast("cuda"):
|
_loss_step += loss.item()
|
||||||
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
|
scaler.scale(loss).backward()
|
||||||
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
|
# go back until we reach gradient accumulation steps
|
||||||
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
|
if (j + 1) % gradient_step != 0:
|
||||||
loss = shared.sd_model(x, c)[0]
|
continue
|
||||||
del x
|
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
|
||||||
del c
|
# scaler.unscale_(optimizer)
|
||||||
|
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
|
||||||
|
# torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
|
||||||
|
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
hypernetwork.step += 1
|
||||||
|
pbar.update()
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
loss_step = _loss_step
|
||||||
|
_loss_step = 0
|
||||||
|
|
||||||
losses[hypernetwork.step % losses.shape[0]] = loss.item()
|
steps_done = hypernetwork.step + 1
|
||||||
for entry in entries:
|
|
||||||
loss_dict[entry.filename].append(loss.item())
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
epoch_num = hypernetwork.step // steps_per_epoch
|
||||||
weights[0].grad = None
|
epoch_step = hypernetwork.step % steps_per_epoch
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
if weights[0].grad is None:
|
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
|
||||||
steps_without_grad += 1
|
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
|
||||||
else:
|
# Before saving, change name to match current checkpoint.
|
||||||
steps_without_grad = 0
|
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
|
||||||
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
|
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
|
||||||
|
hypernetwork.optimizer_name = optimizer_name
|
||||||
|
if shared.opts.save_optimizer_state:
|
||||||
|
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
||||||
|
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
||||||
|
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
||||||
|
|
||||||
optimizer.step()
|
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
|
||||||
|
"loss": f"{loss_step:.7f}",
|
||||||
|
"learn_rate": scheduler.learn_rate
|
||||||
|
})
|
||||||
|
|
||||||
steps_done = hypernetwork.step + 1
|
if images_dir is not None and steps_done % create_image_every == 0:
|
||||||
|
forced_filename = f'{hypernetwork_name}-{steps_done}'
|
||||||
|
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||||
|
hypernetwork.eval_mode()
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
|
|
||||||
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
|
p = processing.StableDiffusionProcessingTxt2Img(
|
||||||
raise RuntimeError("Loss diverged.")
|
sd_model=shared.sd_model,
|
||||||
|
do_not_save_grid=True,
|
||||||
if len(previous_mean_losses) > 1:
|
do_not_save_samples=True,
|
||||||
std = stdev(previous_mean_losses)
|
)
|
||||||
else:
|
|
||||||
std = 0
|
|
||||||
dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
|
|
||||||
pbar.set_description(dataset_loss_info)
|
|
||||||
|
|
||||||
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
|
if preview_from_txt2img:
|
||||||
# Before saving, change name to match current checkpoint.
|
p.prompt = preview_prompt
|
||||||
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
|
p.negative_prompt = preview_negative_prompt
|
||||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
|
p.steps = preview_steps
|
||||||
hypernetwork.optimizer_name = optimizer_name
|
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
||||||
if shared.opts.save_optimizer_state:
|
p.cfg_scale = preview_cfg_scale
|
||||||
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
p.seed = preview_seed
|
||||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
p.width = preview_width
|
||||||
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
p.height = preview_height
|
||||||
|
else:
|
||||||
|
p.prompt = batch.cond_text[0]
|
||||||
|
p.steps = 20
|
||||||
|
p.width = training_width
|
||||||
|
p.height = training_height
|
||||||
|
|
||||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
preview_text = p.prompt
|
||||||
"loss": f"{previous_mean_loss:.7f}",
|
|
||||||
"learn_rate": scheduler.learn_rate
|
|
||||||
})
|
|
||||||
|
|
||||||
if images_dir is not None and steps_done % create_image_every == 0:
|
processed = processing.process_images(p)
|
||||||
forced_filename = f'{hypernetwork_name}-{steps_done}'
|
image = processed.images[0] if len(processed.images) > 0 else None
|
||||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
if unload:
|
||||||
shared.sd_model.cond_stage_model.to(devices.device)
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
shared.sd_model.first_stage_model.to(devices.device)
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
hypernetwork.train_mode()
|
||||||
|
if image is not None:
|
||||||
|
shared.state.current_image = image
|
||||||
|
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||||
|
last_saved_image += f", prompt: {preview_text}"
|
||||||
|
|
||||||
p = processing.StableDiffusionProcessingTxt2Img(
|
shared.state.job_no = hypernetwork.step
|
||||||
sd_model=shared.sd_model,
|
|
||||||
do_not_save_grid=True,
|
|
||||||
do_not_save_samples=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if preview_from_txt2img:
|
shared.state.textinfo = f"""
|
||||||
p.prompt = preview_prompt
|
|
||||||
p.negative_prompt = preview_negative_prompt
|
|
||||||
p.steps = preview_steps
|
|
||||||
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
|
||||||
p.cfg_scale = preview_cfg_scale
|
|
||||||
p.seed = preview_seed
|
|
||||||
p.width = preview_width
|
|
||||||
p.height = preview_height
|
|
||||||
else:
|
|
||||||
p.prompt = entries[0].cond_text
|
|
||||||
p.steps = 20
|
|
||||||
|
|
||||||
preview_text = p.prompt
|
|
||||||
|
|
||||||
processed = processing.process_images(p)
|
|
||||||
image = processed.images[0] if len(processed.images)>0 else None
|
|
||||||
|
|
||||||
if unload:
|
|
||||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
|
||||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
|
||||||
|
|
||||||
if image is not None:
|
|
||||||
shared.state.current_image = image
|
|
||||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
|
||||||
last_saved_image += f", prompt: {preview_text}"
|
|
||||||
|
|
||||||
shared.state.job_no = hypernetwork.step
|
|
||||||
|
|
||||||
shared.state.textinfo = f"""
|
|
||||||
<p>
|
<p>
|
||||||
Loss: {previous_mean_loss:.7f}<br/>
|
Loss: {loss_step:.7f}<br/>
|
||||||
Step: {hypernetwork.step}<br/>
|
Step: {steps_done}<br/>
|
||||||
Last prompt: {html.escape(entries[0].cond_text)}<br/>
|
Last prompt: {html.escape(batch.cond_text[0])}<br/>
|
||||||
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
|
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
|
||||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
</p>
|
</p>
|
||||||
"""
|
"""
|
||||||
|
except Exception:
|
||||||
report_statistics(loss_dict)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
finally:
|
||||||
|
pbar.leave = False
|
||||||
|
pbar.close()
|
||||||
|
hypernetwork.eval_mode()
|
||||||
|
#report_statistics(loss_dict)
|
||||||
|
|
||||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||||
hypernetwork.optimizer_name = optimizer_name
|
hypernetwork.optimizer_name = optimizer_name
|
||||||
|
@ -579,6 +614,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
|
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
|
||||||
del optimizer
|
del optimizer
|
||||||
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
|
|
||||||
return hypernetwork, filename
|
return hypernetwork, filename
|
||||||
|
|
||||||
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
|
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
|
||||||
|
|
|
@ -8,9 +8,9 @@ from torch import einsum
|
||||||
from torch.nn.functional import silu
|
from torch.nn.functional import silu
|
||||||
|
|
||||||
import modules.textual_inversion.textual_inversion
|
import modules.textual_inversion.textual_inversion
|
||||||
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
|
from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import opts, device, cmd_opts
|
||||||
from modules import sd_hijack_clip, sd_hijack_open_clip
|
from modules import sd_hijack_clip, sd_hijack_open_clip
|
||||||
|
|
||||||
from modules.sd_hijack_optimizations import invokeAI_mps_available
|
from modules.sd_hijack_optimizations import invokeAI_mps_available
|
||||||
|
@ -66,6 +66,10 @@ def undo_optimizations():
|
||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||||
|
|
||||||
|
|
||||||
|
def fix_checkpoint():
|
||||||
|
ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
|
||||||
|
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
|
||||||
|
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
|
||||||
|
|
||||||
class StableDiffusionModelHijack:
|
class StableDiffusionModelHijack:
|
||||||
fixes = None
|
fixes = None
|
||||||
|
@ -88,6 +92,7 @@ class StableDiffusionModelHijack:
|
||||||
self.clip = m.cond_stage_model
|
self.clip = m.cond_stage_model
|
||||||
|
|
||||||
apply_optimizations()
|
apply_optimizations()
|
||||||
|
fix_checkpoint()
|
||||||
|
|
||||||
def flatten(el):
|
def flatten(el):
|
||||||
flattened = [flatten(children) for children in el.children()]
|
flattened = [flatten(children) for children in el.children()]
|
||||||
|
|
10
modules/sd_hijack_checkpoint.py
Normal file
10
modules/sd_hijack_checkpoint.py
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
|
def BasicTransformerBlock_forward(self, x, context=None):
|
||||||
|
return checkpoint(self._forward, x, context)
|
||||||
|
|
||||||
|
def AttentionBlock_forward(self, x):
|
||||||
|
return checkpoint(self._forward, x)
|
||||||
|
|
||||||
|
def ResBlock_forward(self, x, emb):
|
||||||
|
return checkpoint(self._forward, x, emb)
|
|
@ -345,8 +345,7 @@ options_templates.update(options_section(('system', "System"), {
|
||||||
|
|
||||||
options_templates.update(options_section(('training', "Training"), {
|
options_templates.update(options_section(('training', "Training"), {
|
||||||
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
||||||
"shuffle_tags": OptionInfo(False, "Shuffleing tags by ',' when create texts."),
|
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
|
||||||
"tag_drop_out": OptionInfo(0, "Dropout tags when create texts", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.1}),
|
|
||||||
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
|
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
|
||||||
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
||||||
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
||||||
|
|
|
@ -3,7 +3,7 @@ import numpy as np
|
||||||
import PIL
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset, DataLoader
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
|
||||||
import random
|
import random
|
||||||
|
@ -11,25 +11,28 @@ import tqdm
|
||||||
from modules import devices, shared
|
from modules import devices, shared
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||||
|
|
||||||
re_numbers_at_start = re.compile(r"^[-\d]+\s*")
|
re_numbers_at_start = re.compile(r"^[-\d]+\s*")
|
||||||
|
|
||||||
|
|
||||||
class DatasetEntry:
|
class DatasetEntry:
|
||||||
def __init__(self, filename=None, latent=None, filename_text=None):
|
def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
self.latent = latent
|
|
||||||
self.filename_text = filename_text
|
self.filename_text = filename_text
|
||||||
self.cond = None
|
self.latent_dist = latent_dist
|
||||||
self.cond_text = None
|
self.latent_sample = latent_sample
|
||||||
|
self.cond = cond
|
||||||
|
self.cond_text = cond_text
|
||||||
|
self.pixel_values = pixel_values
|
||||||
|
|
||||||
|
|
||||||
class PersonalizedBase(Dataset):
|
class PersonalizedBase(Dataset):
|
||||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
|
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
|
||||||
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
|
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
|
||||||
|
|
||||||
self.placeholder_token = placeholder_token
|
self.placeholder_token = placeholder_token
|
||||||
|
|
||||||
self.batch_size = batch_size
|
|
||||||
self.width = width
|
self.width = width
|
||||||
self.height = height
|
self.height = height
|
||||||
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||||
|
@ -45,11 +48,16 @@ class PersonalizedBase(Dataset):
|
||||||
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
||||||
assert os.listdir(data_root), "Dataset directory is empty"
|
assert os.listdir(data_root), "Dataset directory is empty"
|
||||||
|
|
||||||
cond_model = shared.sd_model.cond_stage_model
|
|
||||||
|
|
||||||
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||||
|
|
||||||
|
|
||||||
|
self.shuffle_tags = shuffle_tags
|
||||||
|
self.tag_drop_out = tag_drop_out
|
||||||
|
|
||||||
print("Preparing dataset...")
|
print("Preparing dataset...")
|
||||||
for path in tqdm.tqdm(self.image_paths):
|
for path in tqdm.tqdm(self.image_paths):
|
||||||
|
if shared.state.interrupted:
|
||||||
|
raise Exception("inturrupted")
|
||||||
try:
|
try:
|
||||||
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
|
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -71,37 +79,49 @@ class PersonalizedBase(Dataset):
|
||||||
npimage = np.array(image).astype(np.uint8)
|
npimage = np.array(image).astype(np.uint8)
|
||||||
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
||||||
|
|
||||||
torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
|
torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
|
||||||
torchdata = torch.moveaxis(torchdata, 2, 0)
|
latent_sample = None
|
||||||
|
|
||||||
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
|
with torch.autocast("cuda"):
|
||||||
init_latent = init_latent.to(devices.cpu)
|
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
|
||||||
|
|
||||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
|
if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
|
||||||
|
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||||
|
latent_sampling_method = "once"
|
||||||
|
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
|
||||||
|
elif latent_sampling_method == "deterministic":
|
||||||
|
# Works only for DiagonalGaussianDistribution
|
||||||
|
latent_dist.std = 0
|
||||||
|
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||||
|
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
|
||||||
|
elif latent_sampling_method == "random":
|
||||||
|
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
|
||||||
|
|
||||||
if include_cond:
|
if not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||||
entry.cond_text = self.create_text(filename_text)
|
entry.cond_text = self.create_text(filename_text)
|
||||||
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
|
||||||
|
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||||
|
with torch.autocast("cuda"):
|
||||||
|
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
||||||
|
|
||||||
self.dataset.append(entry)
|
self.dataset.append(entry)
|
||||||
|
del torchdata
|
||||||
|
del latent_dist
|
||||||
|
del latent_sample
|
||||||
|
|
||||||
assert len(self.dataset) > 0, "No images have been found in the dataset."
|
self.length = len(self.dataset)
|
||||||
self.length = len(self.dataset) * repeats // batch_size
|
assert self.length > 0, "No images have been found in the dataset."
|
||||||
|
self.batch_size = min(batch_size, self.length)
|
||||||
self.dataset_length = len(self.dataset)
|
self.gradient_step = min(gradient_step, self.length // self.batch_size)
|
||||||
self.indexes = None
|
self.latent_sampling_method = latent_sampling_method
|
||||||
self.shuffle()
|
|
||||||
|
|
||||||
def shuffle(self):
|
|
||||||
self.indexes = np.random.permutation(self.dataset_length)
|
|
||||||
|
|
||||||
def create_text(self, filename_text):
|
def create_text(self, filename_text):
|
||||||
text = random.choice(self.lines)
|
text = random.choice(self.lines)
|
||||||
text = text.replace("[name]", self.placeholder_token)
|
text = text.replace("[name]", self.placeholder_token)
|
||||||
tags = filename_text.split(',')
|
tags = filename_text.split(',')
|
||||||
if shared.opts.tag_drop_out != 0:
|
if self.tag_drop_out != 0:
|
||||||
tags = [t for t in tags if random.random() > shared.opts.tag_drop_out]
|
tags = [t for t in tags if random.random() > self.tag_drop_out]
|
||||||
if shared.opts.shuffle_tags:
|
if self.shuffle_tags:
|
||||||
random.shuffle(tags)
|
random.shuffle(tags)
|
||||||
text = text.replace("[filewords]", ','.join(tags))
|
text = text.replace("[filewords]", ','.join(tags))
|
||||||
return text
|
return text
|
||||||
|
@ -110,19 +130,43 @@ class PersonalizedBase(Dataset):
|
||||||
return self.length
|
return self.length
|
||||||
|
|
||||||
def __getitem__(self, i):
|
def __getitem__(self, i):
|
||||||
res = []
|
entry = self.dataset[i]
|
||||||
|
if self.tag_drop_out != 0 or self.shuffle_tags:
|
||||||
|
entry.cond_text = self.create_text(entry.filename_text)
|
||||||
|
if self.latent_sampling_method == "random":
|
||||||
|
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
|
||||||
|
return entry
|
||||||
|
|
||||||
for j in range(self.batch_size):
|
class PersonalizedDataLoader(DataLoader):
|
||||||
position = i * self.batch_size + j
|
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
|
||||||
if position % len(self.indexes) == 0:
|
super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
|
||||||
self.shuffle()
|
if latent_sampling_method == "random":
|
||||||
|
self.collate_fn = collate_wrapper_random
|
||||||
|
else:
|
||||||
|
self.collate_fn = collate_wrapper
|
||||||
|
|
||||||
|
|
||||||
index = self.indexes[position % len(self.indexes)]
|
class BatchLoader:
|
||||||
entry = self.dataset[index]
|
def __init__(self, data):
|
||||||
|
self.cond_text = [entry.cond_text for entry in data]
|
||||||
|
self.cond = [entry.cond for entry in data]
|
||||||
|
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
|
||||||
|
#self.emb_index = [entry.emb_index for entry in data]
|
||||||
|
#print(self.latent_sample.device)
|
||||||
|
|
||||||
if entry.cond is None:
|
def pin_memory(self):
|
||||||
entry.cond_text = self.create_text(entry.filename_text)
|
self.latent_sample = self.latent_sample.pin_memory()
|
||||||
|
return self
|
||||||
|
|
||||||
res.append(entry)
|
def collate_wrapper(batch):
|
||||||
|
return BatchLoader(batch)
|
||||||
|
|
||||||
return res
|
class BatchLoaderRandom(BatchLoader):
|
||||||
|
def __init__(self, data):
|
||||||
|
super().__init__(data)
|
||||||
|
|
||||||
|
def pin_memory(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def collate_wrapper_random(batch):
|
||||||
|
return BatchLoaderRandom(batch)
|
|
@ -183,7 +183,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
|
||||||
if shared.opts.training_write_csv_every == 0:
|
if shared.opts.training_write_csv_every == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
if (step + 1) % shared.opts.training_write_csv_every != 0:
|
if step % shared.opts.training_write_csv_every != 0:
|
||||||
return
|
return
|
||||||
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
|
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
|
||||||
|
|
||||||
|
@ -193,21 +193,23 @@ def write_loss(log_directory, filename, step, epoch_len, values):
|
||||||
if write_csv_header:
|
if write_csv_header:
|
||||||
csv_writer.writeheader()
|
csv_writer.writeheader()
|
||||||
|
|
||||||
epoch = step // epoch_len
|
epoch = (step - 1) // epoch_len
|
||||||
epoch_step = step % epoch_len
|
epoch_step = (step - 1) % epoch_len
|
||||||
|
|
||||||
csv_writer.writerow({
|
csv_writer.writerow({
|
||||||
"step": step + 1,
|
"step": step,
|
||||||
"epoch": epoch,
|
"epoch": epoch,
|
||||||
"epoch_step": epoch_step + 1,
|
"epoch_step": epoch_step,
|
||||||
**values,
|
**values,
|
||||||
})
|
})
|
||||||
|
|
||||||
def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
|
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
|
||||||
assert model_name, f"{name} not selected"
|
assert model_name, f"{name} not selected"
|
||||||
assert learn_rate, "Learning rate is empty or 0"
|
assert learn_rate, "Learning rate is empty or 0"
|
||||||
assert isinstance(batch_size, int), "Batch size must be integer"
|
assert isinstance(batch_size, int), "Batch size must be integer"
|
||||||
assert batch_size > 0, "Batch size must be positive"
|
assert batch_size > 0, "Batch size must be positive"
|
||||||
|
assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
|
||||||
|
assert gradient_step > 0, "Gradient accumulation step must be positive"
|
||||||
assert data_root, "Dataset directory is empty"
|
assert data_root, "Dataset directory is empty"
|
||||||
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
||||||
assert os.listdir(data_root), "Dataset directory is empty"
|
assert os.listdir(data_root), "Dataset directory is empty"
|
||||||
|
@ -223,10 +225,10 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat
|
||||||
if save_model_every or create_image_every:
|
if save_model_every or create_image_every:
|
||||||
assert log_directory, "Log directory is empty"
|
assert log_directory, "Log directory is empty"
|
||||||
|
|
||||||
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
save_embedding_every = save_embedding_every or 0
|
save_embedding_every = save_embedding_every or 0
|
||||||
create_image_every = create_image_every or 0
|
create_image_every = create_image_every or 0
|
||||||
validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
|
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
|
||||||
|
|
||||||
shared.state.textinfo = "Initializing textual inversion training..."
|
shared.state.textinfo = "Initializing textual inversion training..."
|
||||||
shared.state.job_count = steps
|
shared.state.job_count = steps
|
||||||
|
@ -254,161 +256,200 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
||||||
else:
|
else:
|
||||||
images_embeds_dir = None
|
images_embeds_dir = None
|
||||||
|
|
||||||
cond_model = shared.sd_model.cond_stage_model
|
|
||||||
|
|
||||||
hijack = sd_hijack.model_hijack
|
hijack = sd_hijack.model_hijack
|
||||||
|
|
||||||
embedding = hijack.embedding_db.word_embeddings[embedding_name]
|
embedding = hijack.embedding_db.word_embeddings[embedding_name]
|
||||||
checkpoint = sd_models.select_checkpoint()
|
checkpoint = sd_models.select_checkpoint()
|
||||||
|
|
||||||
ititial_step = embedding.step or 0
|
initial_step = embedding.step or 0
|
||||||
if ititial_step >= steps:
|
if initial_step >= steps:
|
||||||
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
|
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
|
||||||
return embedding, filename
|
return embedding, filename
|
||||||
|
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
||||||
|
|
||||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||||
|
|
||||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
|
||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
with torch.autocast("cuda"):
|
|
||||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
|
pin_memory = shared.opts.pin_memory
|
||||||
|
|
||||||
|
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
|
||||||
|
|
||||||
|
latent_sampling_method = ds.latent_sampling_method
|
||||||
|
|
||||||
|
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
|
||||||
|
|
||||||
if unload:
|
if unload:
|
||||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
|
||||||
embedding.vec.requires_grad = True
|
embedding.vec.requires_grad = True
|
||||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
|
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
|
||||||
|
scaler = torch.cuda.amp.GradScaler()
|
||||||
|
|
||||||
losses = torch.zeros((32,))
|
batch_size = ds.batch_size
|
||||||
|
gradient_step = ds.gradient_step
|
||||||
|
# n steps = batch_size * gradient_step * n image processed
|
||||||
|
steps_per_epoch = len(ds) // batch_size // gradient_step
|
||||||
|
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
|
||||||
|
loss_step = 0
|
||||||
|
_loss_step = 0 #internal
|
||||||
|
|
||||||
|
|
||||||
last_saved_file = "<none>"
|
last_saved_file = "<none>"
|
||||||
last_saved_image = "<none>"
|
last_saved_image = "<none>"
|
||||||
forced_filename = "<none>"
|
forced_filename = "<none>"
|
||||||
embedding_yet_to_be_embedded = False
|
embedding_yet_to_be_embedded = False
|
||||||
|
|
||||||
|
pbar = tqdm.tqdm(total=steps - initial_step)
|
||||||
|
try:
|
||||||
|
for i in range((steps-initial_step) * gradient_step):
|
||||||
|
if scheduler.finished:
|
||||||
|
break
|
||||||
|
if shared.state.interrupted:
|
||||||
|
break
|
||||||
|
for j, batch in enumerate(dl):
|
||||||
|
# works as a drop_last=True for gradient accumulation
|
||||||
|
if j == max_steps_per_epoch:
|
||||||
|
break
|
||||||
|
scheduler.apply(optimizer, embedding.step)
|
||||||
|
if scheduler.finished:
|
||||||
|
break
|
||||||
|
if shared.state.interrupted:
|
||||||
|
break
|
||||||
|
|
||||||
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
with torch.autocast("cuda"):
|
||||||
for i, entries in pbar:
|
# c = stack_conds(batch.cond).to(devices.device)
|
||||||
embedding.step = i + ititial_step
|
# mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
|
||||||
|
# print(mask)
|
||||||
|
# c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
|
||||||
|
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||||
|
c = shared.sd_model.cond_stage_model(batch.cond_text)
|
||||||
|
loss = shared.sd_model(x, c)[0] / gradient_step
|
||||||
|
del x
|
||||||
|
|
||||||
|
_loss_step += loss.item()
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
|
||||||
|
# go back until we reach gradient accumulation steps
|
||||||
|
if (j + 1) % gradient_step != 0:
|
||||||
|
continue
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
embedding.step += 1
|
||||||
|
pbar.update()
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
loss_step = _loss_step
|
||||||
|
_loss_step = 0
|
||||||
|
|
||||||
scheduler.apply(optimizer, embedding.step)
|
steps_done = embedding.step + 1
|
||||||
if scheduler.finished:
|
|
||||||
break
|
|
||||||
|
|
||||||
if shared.state.interrupted:
|
epoch_num = embedding.step // steps_per_epoch
|
||||||
break
|
epoch_step = embedding.step % steps_per_epoch
|
||||||
|
|
||||||
with torch.autocast("cuda"):
|
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
|
||||||
c = cond_model([entry.cond_text for entry in entries])
|
if embedding_dir is not None and steps_done % save_embedding_every == 0:
|
||||||
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
|
# Before saving, change name to match current checkpoint.
|
||||||
loss = shared.sd_model(x, c)[0]
|
embedding_name_every = f'{embedding_name}-{steps_done}'
|
||||||
del x
|
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
|
||||||
|
#if shared.opts.save_optimizer_state:
|
||||||
|
#embedding.optimizer_state_dict = optimizer.state_dict()
|
||||||
|
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
|
||||||
|
embedding_yet_to_be_embedded = True
|
||||||
|
|
||||||
losses[embedding.step % losses.shape[0]] = loss.item()
|
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
|
||||||
|
"loss": f"{loss_step:.7f}",
|
||||||
|
"learn_rate": scheduler.learn_rate
|
||||||
|
})
|
||||||
|
|
||||||
optimizer.zero_grad()
|
if images_dir is not None and steps_done % create_image_every == 0:
|
||||||
loss.backward()
|
forced_filename = f'{embedding_name}-{steps_done}'
|
||||||
optimizer.step()
|
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||||
|
|
||||||
steps_done = embedding.step + 1
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
|
|
||||||
epoch_num = embedding.step // len(ds)
|
p = processing.StableDiffusionProcessingTxt2Img(
|
||||||
epoch_step = embedding.step % len(ds)
|
sd_model=shared.sd_model,
|
||||||
|
do_not_save_grid=True,
|
||||||
|
do_not_save_samples=True,
|
||||||
|
do_not_reload_embeddings=True,
|
||||||
|
)
|
||||||
|
|
||||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
|
if preview_from_txt2img:
|
||||||
|
p.prompt = preview_prompt
|
||||||
|
p.negative_prompt = preview_negative_prompt
|
||||||
|
p.steps = preview_steps
|
||||||
|
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
||||||
|
p.cfg_scale = preview_cfg_scale
|
||||||
|
p.seed = preview_seed
|
||||||
|
p.width = preview_width
|
||||||
|
p.height = preview_height
|
||||||
|
else:
|
||||||
|
p.prompt = batch.cond_text[0]
|
||||||
|
p.steps = 20
|
||||||
|
p.width = training_width
|
||||||
|
p.height = training_height
|
||||||
|
|
||||||
if embedding_dir is not None and steps_done % save_embedding_every == 0:
|
preview_text = p.prompt
|
||||||
# Before saving, change name to match current checkpoint.
|
|
||||||
embedding_name_every = f'{embedding_name}-{steps_done}'
|
|
||||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
|
|
||||||
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
|
|
||||||
embedding_yet_to_be_embedded = True
|
|
||||||
|
|
||||||
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
|
processed = processing.process_images(p)
|
||||||
"loss": f"{losses.mean():.7f}",
|
image = processed.images[0] if len(processed.images) > 0 else None
|
||||||
"learn_rate": scheduler.learn_rate
|
|
||||||
})
|
|
||||||
|
|
||||||
if images_dir is not None and steps_done % create_image_every == 0:
|
if unload:
|
||||||
forced_filename = f'{embedding_name}-{steps_done}'
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
|
||||||
|
|
||||||
shared.sd_model.first_stage_model.to(devices.device)
|
if image is not None:
|
||||||
|
shared.state.current_image = image
|
||||||
|
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||||
|
last_saved_image += f", prompt: {preview_text}"
|
||||||
|
|
||||||
p = processing.StableDiffusionProcessingTxt2Img(
|
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
||||||
sd_model=shared.sd_model,
|
|
||||||
do_not_save_grid=True,
|
|
||||||
do_not_save_samples=True,
|
|
||||||
do_not_reload_embeddings=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if preview_from_txt2img:
|
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
|
||||||
p.prompt = preview_prompt
|
|
||||||
p.negative_prompt = preview_negative_prompt
|
|
||||||
p.steps = preview_steps
|
|
||||||
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
|
||||||
p.cfg_scale = preview_cfg_scale
|
|
||||||
p.seed = preview_seed
|
|
||||||
p.width = preview_width
|
|
||||||
p.height = preview_height
|
|
||||||
else:
|
|
||||||
p.prompt = entries[0].cond_text
|
|
||||||
p.steps = 20
|
|
||||||
p.width = training_width
|
|
||||||
p.height = training_height
|
|
||||||
|
|
||||||
preview_text = p.prompt
|
info = PngImagePlugin.PngInfo()
|
||||||
|
data = torch.load(last_saved_file)
|
||||||
|
info.add_text("sd-ti-embedding", embedding_to_b64(data))
|
||||||
|
|
||||||
processed = processing.process_images(p)
|
title = "<{}>".format(data.get('name', '???'))
|
||||||
image = processed.images[0]
|
|
||||||
|
|
||||||
if unload:
|
try:
|
||||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
vectorSize = list(data['string_to_param'].values())[0].shape[0]
|
||||||
|
except Exception as e:
|
||||||
|
vectorSize = '?'
|
||||||
|
|
||||||
shared.state.current_image = image
|
checkpoint = sd_models.select_checkpoint()
|
||||||
|
footer_left = checkpoint.model_name
|
||||||
|
footer_mid = '[{}]'.format(checkpoint.hash)
|
||||||
|
footer_right = '{}v {}s'.format(vectorSize, steps_done)
|
||||||
|
|
||||||
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
|
||||||
|
captioned_image = insert_image_data_embed(captioned_image, data)
|
||||||
|
|
||||||
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
|
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
|
||||||
|
embedding_yet_to_be_embedded = False
|
||||||
|
|
||||||
info = PngImagePlugin.PngInfo()
|
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||||
data = torch.load(last_saved_file)
|
last_saved_image += f", prompt: {preview_text}"
|
||||||
info.add_text("sd-ti-embedding", embedding_to_b64(data))
|
|
||||||
|
|
||||||
title = "<{}>".format(data.get('name', '???'))
|
shared.state.job_no = embedding.step
|
||||||
|
|
||||||
try:
|
shared.state.textinfo = f"""
|
||||||
vectorSize = list(data['string_to_param'].values())[0].shape[0]
|
|
||||||
except Exception as e:
|
|
||||||
vectorSize = '?'
|
|
||||||
|
|
||||||
checkpoint = sd_models.select_checkpoint()
|
|
||||||
footer_left = checkpoint.model_name
|
|
||||||
footer_mid = '[{}]'.format(checkpoint.hash)
|
|
||||||
footer_right = '{}v {}s'.format(vectorSize, steps_done)
|
|
||||||
|
|
||||||
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
|
|
||||||
captioned_image = insert_image_data_embed(captioned_image, data)
|
|
||||||
|
|
||||||
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
|
|
||||||
embedding_yet_to_be_embedded = False
|
|
||||||
|
|
||||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
|
||||||
last_saved_image += f", prompt: {preview_text}"
|
|
||||||
|
|
||||||
shared.state.job_no = embedding.step
|
|
||||||
|
|
||||||
shared.state.textinfo = f"""
|
|
||||||
<p>
|
<p>
|
||||||
Loss: {losses.mean():.7f}<br/>
|
Loss: {loss_step:.7f}<br/>
|
||||||
Step: {embedding.step}<br/>
|
Step: {steps_done}<br/>
|
||||||
Last prompt: {html.escape(entries[0].cond_text)}<br/>
|
Last prompt: {html.escape(batch.cond_text[0])}<br/>
|
||||||
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
</p>
|
</p>
|
||||||
"""
|
"""
|
||||||
|
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
||||||
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
except Exception:
|
||||||
shared.sd_model.first_stage_model.to(devices.device)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
pbar.leave = False
|
||||||
|
pbar.close()
|
||||||
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
|
|
||||||
return embedding, filename
|
return embedding, filename
|
||||||
|
|
||||||
|
|
|
@ -1240,7 +1240,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
interrupt_preprocessing = gr.Button("Interrupt")
|
interrupt_preprocessing = gr.Button("Interrupt")
|
||||||
run_preprocess = gr.Button(value="Preprocess", variant='primary')
|
run_preprocess = gr.Button(value="Preprocess", variant='primary')
|
||||||
|
|
||||||
process_split.change(
|
process_split.change(
|
||||||
fn=lambda show: gr_show(show),
|
fn=lambda show: gr_show(show),
|
||||||
|
@ -1267,6 +1267,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
|
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
|
||||||
|
|
||||||
batch_size = gr.Number(label='Batch size', value=1, precision=0)
|
batch_size = gr.Number(label='Batch size', value=1, precision=0)
|
||||||
|
gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0)
|
||||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
||||||
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
|
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
|
||||||
|
@ -1277,6 +1278,11 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
|
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
|
||||||
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
|
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
|
||||||
|
with gr.Row():
|
||||||
|
shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False)
|
||||||
|
tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0)
|
||||||
|
with gr.Row():
|
||||||
|
latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'])
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
interrupt_training = gr.Button(value="Interrupt")
|
interrupt_training = gr.Button(value="Interrupt")
|
||||||
|
@ -1365,11 +1371,15 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
train_embedding_name,
|
train_embedding_name,
|
||||||
embedding_learn_rate,
|
embedding_learn_rate,
|
||||||
batch_size,
|
batch_size,
|
||||||
|
gradient_step,
|
||||||
dataset_directory,
|
dataset_directory,
|
||||||
log_directory,
|
log_directory,
|
||||||
training_width,
|
training_width,
|
||||||
training_height,
|
training_height,
|
||||||
steps,
|
steps,
|
||||||
|
shuffle_tags,
|
||||||
|
tag_drop_out,
|
||||||
|
latent_sampling_method,
|
||||||
create_image_every,
|
create_image_every,
|
||||||
save_embedding_every,
|
save_embedding_every,
|
||||||
template_file,
|
template_file,
|
||||||
|
@ -1390,11 +1400,15 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
train_hypernetwork_name,
|
train_hypernetwork_name,
|
||||||
hypernetwork_learn_rate,
|
hypernetwork_learn_rate,
|
||||||
batch_size,
|
batch_size,
|
||||||
|
gradient_step,
|
||||||
dataset_directory,
|
dataset_directory,
|
||||||
log_directory,
|
log_directory,
|
||||||
training_width,
|
training_width,
|
||||||
training_height,
|
training_height,
|
||||||
steps,
|
steps,
|
||||||
|
shuffle_tags,
|
||||||
|
tag_drop_out,
|
||||||
|
latent_sampling_method,
|
||||||
create_image_every,
|
create_image_every,
|
||||||
save_embedding_every,
|
save_embedding_every,
|
||||||
template_file,
|
template_file,
|
||||||
|
|
Loading…
Reference in a new issue