parent
108be15500
commit
f89829ec3a
1 changed files with 40 additions and 53 deletions
|
@ -36,14 +36,14 @@ class HypernetworkModule(torch.nn.Module):
|
||||||
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
||||||
# if skip_first_layer because first parameters potentially contain negative values
|
# if skip_first_layer because first parameters potentially contain negative values
|
||||||
# if i < 1: continue
|
# if i < 1: continue
|
||||||
if add_layer_norm:
|
|
||||||
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
|
||||||
if activation_func in HypernetworkModule.activation_dict:
|
if activation_func in HypernetworkModule.activation_dict:
|
||||||
linears.append(HypernetworkModule.activation_dict[activation_func]())
|
linears.append(HypernetworkModule.activation_dict[activation_func]())
|
||||||
else:
|
else:
|
||||||
print("Invalid key {} encountered as activation function!".format(activation_func))
|
print("Invalid key {} encountered as activation function!".format(activation_func))
|
||||||
# if use_dropout:
|
# if use_dropout:
|
||||||
# linears.append(torch.nn.Dropout(p=0.3))
|
# linears.append(torch.nn.Dropout(p=0.3))
|
||||||
|
if add_layer_norm:
|
||||||
|
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
||||||
|
|
||||||
self.linear = torch.nn.Sequential(*linears)
|
self.linear = torch.nn.Sequential(*linears)
|
||||||
|
|
||||||
|
@ -115,24 +115,11 @@ class Hypernetwork:
|
||||||
|
|
||||||
for k, layers in self.layers.items():
|
for k, layers in self.layers.items():
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
|
layer.train()
|
||||||
res += layer.trainables()
|
res += layer.trainables()
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def eval(self):
|
|
||||||
for k, layers in self.layers.items():
|
|
||||||
for layer in layers:
|
|
||||||
layer.eval()
|
|
||||||
for items in self.weights():
|
|
||||||
items.requires_grad = False
|
|
||||||
|
|
||||||
def train(self):
|
|
||||||
for k, layers in self.layers.items():
|
|
||||||
for layer in layers:
|
|
||||||
layer.train()
|
|
||||||
for items in self.weights():
|
|
||||||
items.requires_grad = True
|
|
||||||
|
|
||||||
def save(self, filename):
|
def save(self, filename):
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
|
|
||||||
|
@ -303,6 +290,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
|
||||||
hypernetwork = shared.loaded_hypernetwork
|
hypernetwork = shared.loaded_hypernetwork
|
||||||
|
weights = hypernetwork.weights()
|
||||||
|
for weight in weights:
|
||||||
|
weight.requires_grad = True
|
||||||
|
|
||||||
losses = torch.zeros((32,))
|
losses = torch.zeros((32,))
|
||||||
|
|
||||||
last_saved_file = "<none>"
|
last_saved_file = "<none>"
|
||||||
|
@ -313,10 +304,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
return hypernetwork, filename
|
return hypernetwork, filename
|
||||||
|
|
||||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||||
optimizer = torch.optim.AdamW(hypernetwork.weights(), lr=scheduler.learn_rate)
|
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
|
||||||
|
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
|
||||||
|
|
||||||
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
||||||
hypernetwork.train()
|
|
||||||
for i, entries in pbar:
|
for i, entries in pbar:
|
||||||
hypernetwork.step = i + ititial_step
|
hypernetwork.step = i + ititial_step
|
||||||
|
|
||||||
|
@ -337,9 +328,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
|
|
||||||
losses[hypernetwork.step % losses.shape[0]] = loss.item()
|
losses[hypernetwork.step % losses.shape[0]] = loss.item()
|
||||||
|
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
del loss
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
mean_loss = losses.mean()
|
mean_loss = losses.mean()
|
||||||
if torch.isnan(mean_loss):
|
if torch.isnan(mean_loss):
|
||||||
|
@ -356,47 +346,44 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
})
|
})
|
||||||
|
|
||||||
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
|
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
|
||||||
torch.cuda.empty_cache()
|
|
||||||
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
|
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
|
||||||
with torch.no_grad():
|
|
||||||
hypernetwork.eval()
|
|
||||||
shared.sd_model.cond_stage_model.to(devices.device)
|
|
||||||
shared.sd_model.first_stage_model.to(devices.device)
|
|
||||||
|
|
||||||
p = processing.StableDiffusionProcessingTxt2Img(
|
optimizer.zero_grad()
|
||||||
sd_model=shared.sd_model,
|
shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
do_not_save_grid=True,
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
do_not_save_samples=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if preview_from_txt2img:
|
p = processing.StableDiffusionProcessingTxt2Img(
|
||||||
p.prompt = preview_prompt
|
sd_model=shared.sd_model,
|
||||||
p.negative_prompt = preview_negative_prompt
|
do_not_save_grid=True,
|
||||||
p.steps = preview_steps
|
do_not_save_samples=True,
|
||||||
p.sampler_index = preview_sampler_index
|
)
|
||||||
p.cfg_scale = preview_cfg_scale
|
|
||||||
p.seed = preview_seed
|
|
||||||
p.width = preview_width
|
|
||||||
p.height = preview_height
|
|
||||||
else:
|
|
||||||
p.prompt = entries[0].cond_text
|
|
||||||
p.steps = 20
|
|
||||||
|
|
||||||
preview_text = p.prompt
|
if preview_from_txt2img:
|
||||||
|
p.prompt = preview_prompt
|
||||||
|
p.negative_prompt = preview_negative_prompt
|
||||||
|
p.steps = preview_steps
|
||||||
|
p.sampler_index = preview_sampler_index
|
||||||
|
p.cfg_scale = preview_cfg_scale
|
||||||
|
p.seed = preview_seed
|
||||||
|
p.width = preview_width
|
||||||
|
p.height = preview_height
|
||||||
|
else:
|
||||||
|
p.prompt = entries[0].cond_text
|
||||||
|
p.steps = 20
|
||||||
|
|
||||||
processed = processing.process_images(p)
|
preview_text = p.prompt
|
||||||
image = processed.images[0] if len(processed.images)>0 else None
|
|
||||||
|
|
||||||
if unload:
|
processed = processing.process_images(p)
|
||||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
image = processed.images[0] if len(processed.images)>0 else None
|
||||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
|
||||||
|
|
||||||
if image is not None:
|
if unload:
|
||||||
shared.state.current_image = image
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
image.save(last_saved_image)
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
last_saved_image += f", prompt: {preview_text}"
|
|
||||||
|
|
||||||
hypernetwork.train()
|
if image is not None:
|
||||||
|
shared.state.current_image = image
|
||||||
|
image.save(last_saved_image)
|
||||||
|
last_saved_image += f", prompt: {preview_text}"
|
||||||
|
|
||||||
shared.state.job_no = hypernetwork.step
|
shared.state.job_no = hypernetwork.step
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue