fix the merge
parent 8839b372bf
commit 184e670126
1 changed file with 5 additions and 9 deletions
@@ -251,6 +251,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
     if save_model_every or create_image_every:
         assert log_directory, "Log directory is empty"
 
+
 def create_dummy_mask(x, width=None, height=None):
     if shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}:
 
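For context, the create_dummy_mask helper touched by this hunk supplies the image conditioning (img_c) that models with a 'hybrid' or 'concat' conditioning key expect as the "c_concat" input next to the text conditioning. A minimal sketch of that idea, assuming torch and the usual 8x latent downscale; the name dummy_image_conditioning, its signature, and the shapes are illustrative stand-ins, not the repository's implementation:

import torch

def dummy_image_conditioning(latent, width=512, height=512, is_inpainting_model=True):
    # Illustrative stand-in (not the webui's create_dummy_mask): build the
    # extra "c_concat" conditioning a hybrid/concat model expects when the
    # training images carry no real mask.
    if is_inpainting_model:
        # All-zero "masked image" at latent resolution (assumes an 8x VAE
        # downscale), then pad an all-ones mask channel onto the front.
        cond = torch.zeros(latent.shape[0], 4, height // 8, width // 8,
                           device=latent.device, dtype=latent.dtype)
        cond = torch.nn.functional.pad(cond, (0, 0, 0, 0, 1, 0), value=1.0)
    else:
        # Other models only need a small zero placeholder of matching dtype.
        cond = latent.new_zeros(latent.shape[0], 5, 1, 1)
    return cond

With a (2, 4, 64, 64) latent batch and a 512x512 training size, the sketch returns a (2, 5, 64, 64) tensor: four zeroed image channels plus an all-ones mask channel, the same shape for every batch.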
@@ -380,17 +381,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
                     break
 
                 with devices.autocast():
-                    # c = stack_conds(batch.cond).to(devices.device)
-                    # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
-                    # print(mask)
-                    # c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
-
-
-                    if img_c is None:
-                        img_c = create_dummy_mask(c, training_width, training_height)
-
                     x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                     c = shared.sd_model.cond_stage_model(batch.cond_text)
+
+                    if img_c is None:
+                        img_c = create_dummy_mask(c, training_width, training_height)
+
                     cond = {"c_concat": [img_c], "c_crossattn": [c]}
                     loss = shared.sd_model(x, cond)[0] / gradient_step
                     del x
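The fix itself is an ordering change: after the merge, img_c was built with create_dummy_mask(c, ...) above the line in this block that assigns c, and the commit moves the img_c initialization below c = shared.sd_model.cond_stage_model(batch.cond_text), so the conditioning it reads is computed first. A minimal sketch of the corrected per-batch flow, with model, encode_text, and cache as hypothetical stand-ins for the webui objects:

import torch

def train_step(model, encode_text, batch_latents, batch_captions, cache, gradient_step):
    # Sketch of the corrected ordering: compute the text conditioning first,
    # then build the dummy image conditioning from it once and reuse it.
    c = encode_text(batch_captions)  # cross-attention conditioning
    if cache.get("img_c") is None:
        # stand-in for create_dummy_mask(c, training_width, training_height)
        cache["img_c"] = torch.zeros(batch_latents.shape[0], 5, 1, 1,
                                     device=batch_latents.device, dtype=batch_latents.dtype)
    cond = {"c_concat": [cache["img_c"]], "c_crossattn": [c]}
    loss = model(batch_latents, cond)[0] / gradient_step  # scale for gradient accumulation
    return loss

Caching img_c behind the if ... is None guard mirrors the diff: the dummy conditioning is the same for every batch, so it is built on the first pass and reused for the rest of the run.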