Merge pull request #3490 from Nerogar/inpaint_textual_inversion

Fix textual inversion training with inpainting models
This commit is contained in:
AUTOMATIC1111 2023-01-04 17:40:29 +03:00 committed by GitHub
commit 8839b372bf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -251,6 +251,26 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
if save_model_every or create_image_every: if save_model_every or create_image_every:
assert log_directory, "Log directory is empty" assert log_directory, "Log directory is empty"
def create_dummy_mask(x, width=None, height=None):
if shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}:
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
image_conditioning = shared.sd_model.get_first_stage_encoding(shared.sd_model.encode_first_stage(image_conditioning))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype)
else:
# Dummy zero conditioning if we're not using inpainting model.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
return image_conditioning
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0 save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0 create_image_every = create_image_every or 0
@ -341,6 +361,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
forced_filename = "<none>" forced_filename = "<none>"
embedding_yet_to_be_embedded = False embedding_yet_to_be_embedded = False
img_c = None
pbar = tqdm.tqdm(total=steps - initial_step) pbar = tqdm.tqdm(total=steps - initial_step)
try: try:
for i in range((steps-initial_step) * gradient_step): for i in range((steps-initial_step) * gradient_step):
@ -363,9 +384,15 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
# mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory) # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
# print(mask) # print(mask)
# c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory) # c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
if img_c is None:
img_c = create_dummy_mask(c, training_width, training_height)
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
c = shared.sd_model.cond_stage_model(batch.cond_text) c = shared.sd_model.cond_stage_model(batch.cond_text)
loss = shared.sd_model(x, c)[0] / gradient_step cond = {"c_concat": [img_c], "c_crossattn": [c]}
loss = shared.sd_model(x, cond)[0] / gradient_step
del x del x
_loss_step += loss.item() _loss_step += loss.item()