apply lr schedule to hypernets

This commit is contained in:
AUTOMATIC 2022-10-11 22:03:05 +03:00
parent 12f4f4761b
commit d6fcc6b87b
4 changed files with 54 additions and 45 deletions

View file

@ -14,6 +14,7 @@ import torch
from torch import einsum from torch import einsum
from einops import rearrange, repeat from einops import rearrange, repeat
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnSchedule
class HypernetworkModule(torch.nn.Module): class HypernetworkModule(torch.nn.Module):
@ -202,8 +203,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
for weight in weights: for weight in weights:
weight.requires_grad = True weight.requires_grad = True
optimizer = torch.optim.AdamW(weights, lr=learn_rate)
losses = torch.zeros((32,)) losses = torch.zeros((32,))
last_saved_file = "<none>" last_saved_file = "<none>"
@ -213,12 +212,24 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if ititial_step > steps: if ititial_step > steps:
return hypernetwork, filename return hypernetwork, filename
schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
(learn_rate, end_step) = next(schedules)
print(f'Training at rate of {learn_rate} until step {end_step}')
optimizer = torch.optim.AdamW(weights, lr=learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, (x, text, cond) in pbar: for i, (x, text, cond) in pbar:
hypernetwork.step = i + ititial_step hypernetwork.step = i + ititial_step
if hypernetwork.step > steps: if hypernetwork.step > end_step:
try:
(learn_rate, end_step) = next(schedules)
except Exception:
break break
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
for pg in optimizer.param_groups:
pg['lr'] = learn_rate
if shared.state.interrupted: if shared.state.interrupted:
break break

View file

@ -0,0 +1,34 @@
class LearnSchedule:
def __init__(self, learn_rate, max_steps, cur_step=0):
pairs = learn_rate.split(',')
self.rates = []
self.it = 0
self.maxit = 0
for i, pair in enumerate(pairs):
tmp = pair.split(':')
if len(tmp) == 2:
step = int(tmp[1])
if step > cur_step:
self.rates.append((float(tmp[0]), min(step, max_steps)))
self.maxit += 1
if step > max_steps:
return
elif step == -1:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
return
else:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
return
def __iter__(self):
return self
def __next__(self):
if self.it < self.maxit:
self.it += 1
return self.rates[self.it - 1]
else:
raise StopIteration

View file

@ -10,6 +10,7 @@ import datetime
from modules import shared, devices, sd_hijack, processing, sd_models from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnSchedule
class Embedding: class Embedding:
@ -198,11 +199,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if ititial_step > steps: if ititial_step > steps:
return embedding, filename return embedding, filename
tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]) schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
epoch_len = (tr_img_len * num_repeats) + tr_img_len (learn_rate, end_step) = next(schedules)
scheduleIter = iter(LearnSchedule(learn_rate, steps, ititial_step))
(learn_rate, end_step) = next(scheduleIter)
print(f'Training at rate of {learn_rate} until step {end_step}') print(f'Training at rate of {learn_rate} until step {end_step}')
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate) optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
@ -213,7 +211,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if embedding.step > end_step: if embedding.step > end_step:
try: try:
(learn_rate, end_step) = next(scheduleIter) (learn_rate, end_step) = next(schedules)
except: except:
break break
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}') tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
@ -288,37 +286,3 @@ Last saved image: {html.escape(last_saved_image)}<br/>
embedding.save(filename) embedding.save(filename)
return embedding, filename return embedding, filename
class LearnSchedule:
def __init__(self, learn_rate, max_steps, cur_step=0):
pairs = learn_rate.split(',')
self.rates = []
self.it = 0
self.maxit = 0
for i, pair in enumerate(pairs):
tmp = pair.split(':')
if len(tmp) == 2:
step = int(tmp[1])
if step > cur_step:
self.rates.append((float(tmp[0]), min(step, max_steps)))
self.maxit += 1
if step > max_steps:
return
elif step == -1:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
return
else:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
return
def __iter__(self):
return self
def __next__(self):
if self.it < self.maxit:
self.it += 1
return self.rates[self.it - 1]
else:
raise StopIteration

View file

@ -1070,7 +1070,7 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>") gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()]) train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value = "5.0e-03") learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))