diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 7d617680..84e7e350 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -32,10 +32,16 @@ class HypernetworkModule(torch.nn.Module): linears = [] for i in range(len(layer_structure) - 1): linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) + if activation_func == "relu": linears.append(torch.nn.ReLU()) - if activation_func == "leakyrelu": + elif activation_func == "leakyrelu": linears.append(torch.nn.LeakyReLU()) + elif activation_func == 'linear' or activation_func is None: + pass + else: + raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}') + if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) @@ -46,7 +52,7 @@ class HypernetworkModule(torch.nn.Module): self.load_state_dict(state_dict) else: for layer in self.linear: - if not "ReLU" in layer.__str__(): + if type(layer) == torch.nn.Linear: layer.weight.data.normal_(mean=0.0, std=0.01) layer.bias.data.zero_() @@ -74,7 +80,7 @@ class HypernetworkModule(torch.nn.Module): def trainables(self): layer_structure = [] for layer in self.linear: - if not "ReLU" in layer.__str__(): + if type(layer) == torch.nn.Linear: layer_structure += [layer.weight, layer.bias] return layer_structure