apply lr schedule to hypernets
This commit is contained in:
parent
12f4f4761b
commit
d6fcc6b87b
4 changed files with 54 additions and 45 deletions
|
@ -14,6 +14,7 @@ import torch
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
import modules.textual_inversion.dataset
|
import modules.textual_inversion.dataset
|
||||||
|
from modules.textual_inversion.learn_schedule import LearnSchedule
|
||||||
|
|
||||||
|
|
||||||
class HypernetworkModule(torch.nn.Module):
|
class HypernetworkModule(torch.nn.Module):
|
||||||
|
@ -202,8 +203,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||||
for weight in weights:
|
for weight in weights:
|
||||||
weight.requires_grad = True
|
weight.requires_grad = True
|
||||||
|
|
||||||
optimizer = torch.optim.AdamW(weights, lr=learn_rate)
|
|
||||||
|
|
||||||
losses = torch.zeros((32,))
|
losses = torch.zeros((32,))
|
||||||
|
|
||||||
last_saved_file = "<none>"
|
last_saved_file = "<none>"
|
||||||
|
@ -213,12 +212,24 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||||
if ititial_step > steps:
|
if ititial_step > steps:
|
||||||
return hypernetwork, filename
|
return hypernetwork, filename
|
||||||
|
|
||||||
|
schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
|
||||||
|
(learn_rate, end_step) = next(schedules)
|
||||||
|
print(f'Training at rate of {learn_rate} until step {end_step}')
|
||||||
|
|
||||||
|
optimizer = torch.optim.AdamW(weights, lr=learn_rate)
|
||||||
|
|
||||||
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
||||||
for i, (x, text, cond) in pbar:
|
for i, (x, text, cond) in pbar:
|
||||||
hypernetwork.step = i + ititial_step
|
hypernetwork.step = i + ititial_step
|
||||||
|
|
||||||
if hypernetwork.step > steps:
|
if hypernetwork.step > end_step:
|
||||||
break
|
try:
|
||||||
|
(learn_rate, end_step) = next(schedules)
|
||||||
|
except Exception:
|
||||||
|
break
|
||||||
|
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
|
||||||
|
for pg in optimizer.param_groups:
|
||||||
|
pg['lr'] = learn_rate
|
||||||
|
|
||||||
if shared.state.interrupted:
|
if shared.state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
34
modules/textual_inversion/learn_schedule.py
Normal file
34
modules/textual_inversion/learn_schedule.py
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
|
||||||
|
class LearnSchedule:
    """Iterator over ``(learn_rate, end_step)`` phases of a training schedule.

    The schedule is given as comma-separated ``rate:step`` segments, e.g.
    ``"0.005:100, 1e-3:1000, 1e-5"``: use 0.005 until step 100, then 0.001
    until step 1000, then 1e-5 until ``max_steps``.  A segment without a step
    (or with step ``-1``) lasts until ``max_steps`` and ends the schedule.

    Args:
        learn_rate: Schedule text, or a bare number (converted with ``str``).
        max_steps: Last training step; every ``end_step`` is capped to it.
        cur_step: Step training resumes from; segments that end at or before
            it are dropped.

    Attributes:
        rates: List of ``(float rate, int end_step)`` tuples, in order.
        it: Index of the next tuple ``__next__`` will return.
        maxit: Number of tuples in ``rates``.

    Raises:
        ValueError: If the schedule has no non-blank segment, or a rate/step
            cannot be parsed as a number.
    """

    def __init__(self, learn_rate, max_steps, cur_step=0):
        self.rates = []
        self.it = 0
        self.maxit = 0

        # str() lets callers hand over a plain float as well as schedule text.
        # Blank segments (e.g. from a trailing comma) are ignored.
        segments = [seg for seg in str(learn_rate).split(',') if seg.strip()]
        if not segments:
            raise ValueError("learning rate schedule is empty")

        for segment in segments:
            parts = segment.split(':')

            if len(parts) != 2:
                # No explicit end step: this rate applies for the rest of training.
                self._add_phase(float(parts[0]), max_steps)
                return

            step = int(parts[1])
            if step > cur_step:
                self._add_phase(float(parts[0]), min(step, max_steps))
                # Once a phase reaches max_steps nothing after it can run;
                # stopping here (>=, not >) avoids emitting a zero-length
                # phase that would push training one step past max_steps.
                if step >= max_steps:
                    return
            elif step == -1:
                # -1 is the explicit "until the end" marker.
                self._add_phase(float(parts[0]), max_steps)
                return
            # Otherwise the segment ended before cur_step: already consumed.

    def _add_phase(self, rate, end_step):
        """Append one ``(rate, end_step)`` phase and keep ``maxit`` in sync."""
        self.rates.append((rate, end_step))
        self.maxit += 1

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next ``(learn_rate, end_step)`` tuple or raise StopIteration."""
        if self.it >= self.maxit:
            raise StopIteration
        phase = self.rates[self.it]
        self.it += 1
        return phase
|
|
@ -10,6 +10,7 @@ import datetime
|
||||||
|
|
||||||
from modules import shared, devices, sd_hijack, processing, sd_models
|
from modules import shared, devices, sd_hijack, processing, sd_models
|
||||||
import modules.textual_inversion.dataset
|
import modules.textual_inversion.dataset
|
||||||
|
from modules.textual_inversion.learn_schedule import LearnSchedule
|
||||||
|
|
||||||
|
|
||||||
class Embedding:
|
class Embedding:
|
||||||
|
@ -198,11 +199,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||||
if ititial_step > steps:
|
if ititial_step > steps:
|
||||||
return embedding, filename
|
return embedding, filename
|
||||||
|
|
||||||
tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
|
schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
|
||||||
epoch_len = (tr_img_len * num_repeats) + tr_img_len
|
(learn_rate, end_step) = next(schedules)
|
||||||
|
|
||||||
scheduleIter = iter(LearnSchedule(learn_rate, steps, ititial_step))
|
|
||||||
(learn_rate, end_step) = next(scheduleIter)
|
|
||||||
print(f'Training at rate of {learn_rate} until step {end_step}')
|
print(f'Training at rate of {learn_rate} until step {end_step}')
|
||||||
|
|
||||||
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
|
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
|
||||||
|
@ -213,7 +211,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||||
|
|
||||||
if embedding.step > end_step:
|
if embedding.step > end_step:
|
||||||
try:
|
try:
|
||||||
(learn_rate, end_step) = next(scheduleIter)
|
(learn_rate, end_step) = next(schedules)
|
||||||
except:
|
except:
|
||||||
break
|
break
|
||||||
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
|
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
|
||||||
|
@ -288,37 +286,3 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
embedding.save(filename)
|
embedding.save(filename)
|
||||||
|
|
||||||
return embedding, filename
|
return embedding, filename
|
||||||
|
|
||||||
class LearnSchedule:
    """Iterator over ``(learn_rate, end_step)`` phases of a training schedule.

    The schedule is given as comma-separated ``rate:step`` segments, e.g.
    ``"0.005:100, 1e-3:1000, 1e-5"``: use 0.005 until step 100, then 0.001
    until step 1000, then 1e-5 until ``max_steps``.  A segment without a step
    (or with step ``-1``) lasts until ``max_steps`` and ends the schedule.

    Args:
        learn_rate: Schedule text, or a bare number (converted with ``str``).
        max_steps: Last training step; every ``end_step`` is capped to it.
        cur_step: Step training resumes from; segments that end at or before
            it are dropped.

    Attributes:
        rates: List of ``(float rate, int end_step)`` tuples, in order.
        it: Index of the next tuple ``__next__`` will return.
        maxit: Number of tuples in ``rates``.

    Raises:
        ValueError: If the schedule has no non-blank segment, or a rate/step
            cannot be parsed as a number.
    """

    def __init__(self, learn_rate, max_steps, cur_step=0):
        self.rates = []
        self.it = 0
        self.maxit = 0

        # str() lets callers hand over a plain float as well as schedule text.
        # Blank segments (e.g. from a trailing comma) are ignored.
        segments = [seg for seg in str(learn_rate).split(',') if seg.strip()]
        if not segments:
            raise ValueError("learning rate schedule is empty")

        for segment in segments:
            parts = segment.split(':')

            if len(parts) != 2:
                # No explicit end step: this rate applies for the rest of training.
                self._add_phase(float(parts[0]), max_steps)
                return

            step = int(parts[1])
            if step > cur_step:
                self._add_phase(float(parts[0]), min(step, max_steps))
                # Once a phase reaches max_steps nothing after it can run;
                # stopping here (>=, not >) avoids emitting a zero-length
                # phase that would push training one step past max_steps.
                if step >= max_steps:
                    return
            elif step == -1:
                # -1 is the explicit "until the end" marker.
                self._add_phase(float(parts[0]), max_steps)
                return
            # Otherwise the segment ended before cur_step: already consumed.

    def _add_phase(self, rate, end_step):
        """Append one ``(rate, end_step)`` phase and keep ``maxit`` in sync."""
        self.rates.append((rate, end_step))
        self.maxit += 1

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next ``(learn_rate, end_step)`` tuple or raise StopIteration."""
        if self.it >= self.maxit:
            raise StopIteration
        phase = self.rates[self.it]
        self.it += 1
        return phase
|
|
||||||
|
|
|
@ -1070,7 +1070,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
|
||||||
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||||
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
|
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
|
||||||
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value = "5.0e-03")
|
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
|
||||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
||||||
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
|
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
|
||||||
|
|
Loading…
Reference in a new issue