Merge pull request #2 from aria1th/patch-6
generalized some functions and option for ignoring first layer
Commit 6a02841fff

1 changed file with 15 additions and 8 deletions
@@ -21,21 +21,27 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler


 class HypernetworkModule(torch.nn.Module):
     multiplier = 1.0
+    activation_dict = {"relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU, "elu": torch.nn.ELU,
+                       "swish": torch.nn.Hardswish}

     def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
         super().__init__()

         assert layer_structure is not None, "layer_structure must not be None"
         assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
         assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"

         linears = []
         for i in range(len(layer_structure) - 1):
             linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
-            if activation_func == "relu":
-                linears.append(torch.nn.ReLU())
-            if activation_func == "leakyrelu":
-                linears.append(torch.nn.LeakyReLU())
+            # if skip_first_layer because first parameters potentially contain negative values
+            # if i < 1: continue
+            if activation_func in HypernetworkModule.activation_dict:
+                linears.append(HypernetworkModule.activation_dict[activation_func]())
+            else:
+                print("Invalid key {} encountered as activation function!".format(activation_func))
+            # if use_dropout:
+            #     linears.append(torch.nn.Dropout(p=0.3))
             if add_layer_norm:
                 linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))

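The net effect of this hunk is to replace the hard-coded relu/leakyrelu branches with a class-level lookup table, so supporting a new activation only needs one dictionary entry. A minimal, self-contained sketch of the same dispatch pattern; the helper name build_linears and the example dimensions are illustrative, not taken from the repository:

import torch

# Illustrative lookup table mirroring HypernetworkModule.activation_dict
# (note "swish" maps to Hardswish, as in the patch).
activation_dict = {
    "relu": torch.nn.ReLU,
    "leakyrelu": torch.nn.LeakyReLU,
    "elu": torch.nn.ELU,
    "swish": torch.nn.Hardswish,
}

def build_linears(dim, layer_structure, activation_func=None, add_layer_norm=False):
    """Build the Linear / activation / LayerNorm stack the same way the patched loop does."""
    linears = []
    for i in range(len(layer_structure) - 1):
        linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i + 1])))
        if activation_func in activation_dict:
            linears.append(activation_dict[activation_func]())
        else:
            print("Invalid key {} encountered as activation function!".format(activation_func))
        if add_layer_norm:
            linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i + 1])))
    return torch.nn.Sequential(*linears)

# Example: a 1 -> 2 -> 1 multiplier structure over a 320-dim context.
model = build_linears(320, [1, 2, 1], activation_func="elu")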
@@ -46,7 +52,7 @@ class HypernetworkModule(torch.nn.Module):
             self.load_state_dict(state_dict)
         else:
             for layer in self.linear:
-                if not "ReLU" in layer.__str__():
+                if isinstance(layer, torch.nn.Linear):
                     layer.weight.data.normal_(mean=0.0, std=0.01)
                     layer.bias.data.zero_()

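The isinstance(layer, torch.nn.Linear) test replaces string matching on the layer's repr. With the new activation table, modules like ELU or Hardswish no longer contain "ReLU" in their repr, so the old check would try to initialize weights they do not have; the same swap appears again in trainables() below. A small illustration of which layers each check would touch (illustrative only):

import torch

layers = [torch.nn.Linear(4, 4), torch.nn.ELU(), torch.nn.LayerNorm(4)]

for layer in layers:
    # Old check: relies on the repr containing "ReLU"; ELU slips through and
    # would raise AttributeError because it has no .weight to initialize.
    old_would_init = "ReLU" not in str(layer)
    # New check: only true Linear layers get the normal_ / zero_ initialization.
    new_would_init = isinstance(layer, torch.nn.Linear)
    print(type(layer).__name__, old_would_init, new_would_init)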
@@ -74,7 +80,7 @@ class HypernetworkModule(torch.nn.Module):
     def trainables(self):
         layer_structure = []
         for layer in self.linear:
-            if not "ReLU" in layer.__str__():
+            if isinstance(layer, torch.nn.Linear):
                 layer_structure += [layer.weight, layer.bias]
         return layer_structure

@@ -298,6 +304,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
         return hypernetwork, filename

     scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+    # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
     optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)

     pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
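The added comment only sketches a future direction: choosing the optimizer by name instead of hard-coding AdamW. A hedged sketch of what such a lookup could look like; the make_optimizer helper, optimizer_name parameter, and optimizer_dict are hypothetical and not part of this patch:

import torch

# Hypothetical name -> optimizer class table; only AdamW is actually used by the patch.
optimizer_dict = {
    "Adam": torch.optim.Adam,
    "AdamW": torch.optim.AdamW,
    "SGD": torch.optim.SGD,
}

def make_optimizer(weights, learn_rate, optimizer_name="AdamW"):
    """Pick the optimizer class by name, falling back to AdamW for unknown keys."""
    cls = optimizer_dict.get(optimizer_name, torch.optim.AdamW)
    return cls(weights, lr=learn_rate)

# Usage with a throwaway parameter so the sketch runs standalone:
weights = [torch.nn.Parameter(torch.zeros(4, 4))]
optimizer = make_optimizer(weights, learn_rate=5e-3, optimizer_name="SGD")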