use shared function from processing for creating dummy mask when training inpainting model
This commit is contained in:
parent
184e670126
commit
525cea9245
2 changed files with 29 additions and 43 deletions
|
@ -76,6 +76,24 @@ def apply_overlay(image, paste_loc, index, overlays):
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def txt2img_image_conditioning(sd_model, x, width, height):
|
||||||
|
if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
|
||||||
|
# Dummy zero conditioning if we're not using inpainting model.
|
||||||
|
# Still takes up a bit of memory, but no encoder call.
|
||||||
|
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
|
||||||
|
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
|
||||||
|
|
||||||
|
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
||||||
|
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
|
||||||
|
image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
|
||||||
|
|
||||||
|
# Add the fake full 1s mask to the first dimension.
|
||||||
|
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
||||||
|
image_conditioning = image_conditioning.to(x.dtype)
|
||||||
|
|
||||||
|
return image_conditioning
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionProcessing():
|
class StableDiffusionProcessing():
|
||||||
"""
|
"""
|
||||||
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
||||||
|
@ -139,26 +157,9 @@ class StableDiffusionProcessing():
|
||||||
self.iteration = 0
|
self.iteration = 0
|
||||||
|
|
||||||
def txt2img_image_conditioning(self, x, width=None, height=None):
|
def txt2img_image_conditioning(self, x, width=None, height=None):
|
||||||
if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
|
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
|
||||||
# Dummy zero conditioning if we're not using inpainting model.
|
|
||||||
# Still takes up a bit of memory, but no encoder call.
|
|
||||||
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
|
|
||||||
return x.new_zeros(x.shape[0], 5, 1, 1)
|
|
||||||
|
|
||||||
self.is_using_inpainting_conditioning = True
|
return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
|
||||||
|
|
||||||
height = height or self.height
|
|
||||||
width = width or self.width
|
|
||||||
|
|
||||||
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
|
||||||
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
|
|
||||||
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
|
|
||||||
|
|
||||||
# Add the fake full 1s mask to the first dimension.
|
|
||||||
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
|
||||||
image_conditioning = image_conditioning.to(x.dtype)
|
|
||||||
|
|
||||||
return image_conditioning
|
|
||||||
|
|
||||||
def depth2img_image_conditioning(self, source_image):
|
def depth2img_image_conditioning(self, source_image):
|
||||||
# Use the AddMiDaS helper to Format our source image to suit the MiDaS model
|
# Use the AddMiDaS helper to Format our source image to suit the MiDaS model
|
||||||
|
|
|
@ -252,26 +252,6 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
|
||||||
assert log_directory, "Log directory is empty"
|
assert log_directory, "Log directory is empty"
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_mask(x, width=None, height=None):
|
|
||||||
if shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}:
|
|
||||||
|
|
||||||
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
|
||||||
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
|
|
||||||
image_conditioning = shared.sd_model.get_first_stage_encoding(shared.sd_model.encode_first_stage(image_conditioning))
|
|
||||||
|
|
||||||
# Add the fake full 1s mask to the first dimension.
|
|
||||||
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
|
||||||
image_conditioning = image_conditioning.to(x.dtype)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Dummy zero conditioning if we're not using inpainting model.
|
|
||||||
# Still takes up a bit of memory, but no encoder call.
|
|
||||||
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
|
|
||||||
image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
|
|
||||||
|
|
||||||
return image_conditioning
|
|
||||||
|
|
||||||
|
|
||||||
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
save_embedding_every = save_embedding_every or 0
|
save_embedding_every = save_embedding_every or 0
|
||||||
create_image_every = create_image_every or 0
|
create_image_every = create_image_every or 0
|
||||||
|
@ -346,7 +326,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||||
else:
|
else:
|
||||||
print("No saved optimizer exists in checkpoint")
|
print("No saved optimizer exists in checkpoint")
|
||||||
|
|
||||||
|
|
||||||
scaler = torch.cuda.amp.GradScaler()
|
scaler = torch.cuda.amp.GradScaler()
|
||||||
|
|
||||||
batch_size = ds.batch_size
|
batch_size = ds.batch_size
|
||||||
|
@ -362,7 +341,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||||
forced_filename = "<none>"
|
forced_filename = "<none>"
|
||||||
embedding_yet_to_be_embedded = False
|
embedding_yet_to_be_embedded = False
|
||||||
|
|
||||||
|
is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
|
||||||
img_c = None
|
img_c = None
|
||||||
|
|
||||||
pbar = tqdm.tqdm(total=steps - initial_step)
|
pbar = tqdm.tqdm(total=steps - initial_step)
|
||||||
try:
|
try:
|
||||||
for i in range((steps-initial_step) * gradient_step):
|
for i in range((steps-initial_step) * gradient_step):
|
||||||
|
@ -384,10 +365,14 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||||
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||||
c = shared.sd_model.cond_stage_model(batch.cond_text)
|
c = shared.sd_model.cond_stage_model(batch.cond_text)
|
||||||
|
|
||||||
if img_c is None:
|
if is_training_inpainting_model:
|
||||||
img_c = create_dummy_mask(c, training_width, training_height)
|
if img_c is None:
|
||||||
|
img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
|
||||||
|
|
||||||
|
cond = {"c_concat": [img_c], "c_crossattn": [c]}
|
||||||
|
else:
|
||||||
|
cond = c
|
||||||
|
|
||||||
cond = {"c_concat": [img_c], "c_crossattn": [c]}
|
|
||||||
loss = shared.sd_model(x, cond)[0] / gradient_step
|
loss = shared.sd_model(x, cond)[0] / gradient_step
|
||||||
del x
|
del x
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue