Refactored learning rate code

This commit is contained in:
Fampai 2022-10-10 17:10:29 -04:00
parent ce37fdd30e
commit 2536ecbb17
2 changed files with 48 additions and 5 deletions

View file

@ -189,8 +189,6 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
embedding = hijack.embedding_db.word_embeddings[embedding_name] embedding = hijack.embedding_db.word_embeddings[embedding_name]
embedding.vec.requires_grad = True embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
losses = torch.zeros((32,)) losses = torch.zeros((32,))
last_saved_file = "<none>" last_saved_file = "<none>"
@ -203,12 +201,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]) tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
epoch_len = (tr_img_len * num_repeats) + tr_img_len epoch_len = (tr_img_len * num_repeats) + tr_img_len
scheduleIter = iter(LearnSchedule(learn_rate, steps, ititial_step))
(learn_rate, end_step) = next(scheduleIter)
print(f'Training at rate of {learn_rate} until step {end_step}')
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, (x, text) in pbar: for i, (x, text) in pbar:
embedding.step = i + ititial_step embedding.step = i + ititial_step
if embedding.step > steps: if embedding.step > end_step:
try:
(learn_rate, end_step) = next(scheduleIter)
except:
break break
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
for pg in optimizer.param_groups:
pg['lr'] = learn_rate
if shared.state.interrupted: if shared.state.interrupted:
break break
@ -277,3 +287,36 @@ Last saved image: {html.escape(last_saved_image)}<br/>
return embedding, filename return embedding, filename
class LearnSchedule:
def __init__(self, learn_rate, max_steps, cur_step=0):
pairs = learn_rate.split(',')
self.rates = []
self.it = 0
self.maxit = 0
for i, pair in enumerate(pairs):
tmp = pair.split(':')
if len(tmp) == 2:
step = int(tmp[1])
if step > cur_step:
self.rates.append((float(tmp[0]), min(step, max_steps)))
self.maxit += 1
if step > max_steps:
return
elif step == -1:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
return
else:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
return
def __iter__(self):
return self
def __next__(self):
if self.it < self.maxit:
self.it += 1
return self.rates[self.it - 1]
else:
raise StopIteration

View file

@ -1047,7 +1047,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Group(): with gr.Group():
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>") gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
learn_rate = gr.Number(label='Learning rate', value=5.0e-03) learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value = "5.0e-03")
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))