set TI AdamW default weight decay to 0

flamelaw 2022-11-27 00:35:44 +09:00
parent 1bd57cc979
commit 755df94b2a


@ -283,7 +283,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
     shared.sd_model.first_stage_model.to(devices.cpu)
 
     embedding.vec.requires_grad = True
-    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
+    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
     scaler = torch.cuda.amp.GradScaler()
 
     batch_size = ds.batch_size
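
Context for the change, as a minimal sketch (assumes PyTorch; the learning rate and tensor values below are illustrative, not from the commit): torch.optim.AdamW applies decoupled weight decay with a default of 1e-2, which multiplicatively shrinks the embedding vector on every step even when the gradient is zero. Passing weight_decay=0.0 disables that shrinkage during textual inversion training.

import torch

# Toy illustration: one AdamW step on a parameter whose gradient is zero,
# with and without the default weight_decay of 1e-2.
for wd in (1e-2, 0.0):
    p = torch.ones(4, requires_grad=True)
    opt = torch.optim.AdamW([p], lr=0.1, weight_decay=wd)
    p.grad = torch.zeros_like(p)   # pretend the loss contributes no gradient
    opt.step()
    print(f"weight_decay={wd}: {p.data}")

# weight_decay=0.01: values shrink to 0.9990 even though the gradient is zero,
# because decoupled decay multiplies the parameter by (1 - lr * wd) each step.
# weight_decay=0.0:  values stay exactly 1.0.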