diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 7d617680..905cbeef 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -21,21 +21,27 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
 
 class HypernetworkModule(torch.nn.Module):
     multiplier = 1.0
-
+    activation_dict = {"relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU, "elu": torch.nn.ELU,
+                       "swish": torch.nn.Hardswish}
+
     def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
         super().__init__()
 
         assert layer_structure is not None, "layer_structure must not be None"
         assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
         assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
-
+
         linears = []
         for i in range(len(layer_structure) - 1):
             linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
-            if activation_func == "relu":
-                linears.append(torch.nn.ReLU())
-            if activation_func == "leakyrelu":
-                linears.append(torch.nn.LeakyReLU())
+            # if skip_first_layer because first parameters potentially contain negative values
+            # if i < 1: continue
+            if activation_func in HypernetworkModule.activation_dict:
+                linears.append(HypernetworkModule.activation_dict[activation_func]())
+            else:
+                print("Invalid key {} encountered as activation function!".format(activation_func))
+            # if use_dropout:
+            #     linears.append(torch.nn.Dropout(p=0.3))
             if add_layer_norm:
                 linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
 
@@ -46,7 +52,7 @@ class HypernetworkModule(torch.nn.Module):
             self.load_state_dict(state_dict)
         else:
             for layer in self.linear:
-                if not "ReLU" in layer.__str__():
+                if isinstance(layer, torch.nn.Linear):
                     layer.weight.data.normal_(mean=0.0, std=0.01)
                     layer.bias.data.zero_()
 
@@ -74,7 +80,7 @@ class HypernetworkModule(torch.nn.Module):
     def trainables(self):
         layer_structure = []
         for layer in self.linear:
-            if not "ReLU" in layer.__str__():
+            if isinstance(layer, torch.nn.Linear):
                 layer_structure += [layer.weight, layer.bias]
         return layer_structure
 
@@ -298,6 +304,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
         return hypernetwork, filename
 
     scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+    # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
    optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
 
    pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
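For context: the loop above now resolves the activation layer through the class-level activation_dict instead of hard-coded relu/leakyrelu branches, so supporting a new activation only needs a new dictionary entry (note that "swish" is mapped to torch.nn.Hardswish). The snippet below is a minimal standalone sketch of that lookup pattern, not part of the patch; the helper name build_linears and the example sizes are illustrative assumptions.

import torch

# Same mapping as the patch introduces on HypernetworkModule (class attribute).
activation_dict = {"relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU, "elu": torch.nn.ELU,
                   "swish": torch.nn.Hardswish}

def build_linears(dim, layer_structure, activation_func=None, add_layer_norm=False):
    # Mirrors the loop in HypernetworkModule.__init__: Linear -> activation -> optional LayerNorm.
    # build_linears is a hypothetical helper used only for this illustration.
    linears = []
    for i in range(len(layer_structure) - 1):
        linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i + 1])))
        if activation_func in activation_dict:
            linears.append(activation_dict[activation_func]())
        else:
            print("Invalid key {} encountered as activation function!".format(activation_func))
        if add_layer_norm:
            linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i + 1])))
    return torch.nn.Sequential(*linears)

# Example: a 768-dim module with a 1-2-1 layer structure and ELU activations (sizes are illustrative).
net = build_linears(768, [1, 2, 1], activation_func="elu")
x = torch.randn(4, 768)
print(net(x).shape)  # torch.Size([4, 768])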