cleanup some code
This commit is contained in:
parent
b297cc3324
commit
40b56c9289
1 changed files with 3 additions and 11 deletions
|
@ -16,6 +16,7 @@ from modules.textual_inversion import textual_inversion
|
||||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
|
|
||||||
|
from collections import defaultdict, deque
|
||||||
from statistics import stdev, mean
|
from statistics import stdev, mean
|
||||||
|
|
||||||
class HypernetworkModule(torch.nn.Module):
|
class HypernetworkModule(torch.nn.Module):
|
||||||
|
@ -269,15 +270,6 @@ def stack_conds(conds):
|
||||||
return torch.stack(conds)
|
return torch.stack(conds)
|
||||||
|
|
||||||
|
|
||||||
def log_statistics(loss_info:dict, key, value):
|
|
||||||
if key not in loss_info:
|
|
||||||
loss_info[key] = [value]
|
|
||||||
else:
|
|
||||||
loss_info[key].append(value)
|
|
||||||
if len(loss_info[key]) > 1024:
|
|
||||||
loss_info[key].pop(0)
|
|
||||||
|
|
||||||
|
|
||||||
def statistics(data):
|
def statistics(data):
|
||||||
total_information = f"loss:{mean(data):.3f}"+u"\u00B1"+f"({stdev(data)/ (len(data)**0.5):.3f})"
|
total_information = f"loss:{mean(data):.3f}"+u"\u00B1"+f"({stdev(data)/ (len(data)**0.5):.3f})"
|
||||||
recent_data = data[-32:]
|
recent_data = data[-32:]
|
||||||
|
@ -341,7 +333,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
weight.requires_grad = True
|
weight.requires_grad = True
|
||||||
|
|
||||||
size = len(ds.indexes)
|
size = len(ds.indexes)
|
||||||
loss_dict = {}
|
loss_dict = defaultdict(lambda : deque(maxlen = 1024))
|
||||||
losses = torch.zeros((size,))
|
losses = torch.zeros((size,))
|
||||||
previous_mean_loss = 0
|
previous_mean_loss = 0
|
||||||
print("Mean loss of {} elements".format(size))
|
print("Mean loss of {} elements".format(size))
|
||||||
|
@ -383,7 +375,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
|
|
||||||
losses[hypernetwork.step % losses.shape[0]] = loss.item()
|
losses[hypernetwork.step % losses.shape[0]] = loss.item()
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
log_statistics(loss_dict, entry.filename, loss.item())
|
loss_dict[entry.filename].append(loss.item())
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
weights[0].grad = None
|
weights[0].grad = None
|
||||||
|
|
Loading…
Reference in a new issue