a stricter check for activation type and a more reasonable check for layer type in hypernetworks

This commit is contained in:
AUTOMATIC 2022-10-21 09:47:43 +03:00
parent a26fc2834c
commit c23f666dba

View file

@ -32,10 +32,16 @@ class HypernetworkModule(torch.nn.Module):
linears = [] linears = []
for i in range(len(layer_structure) - 1): for i in range(len(layer_structure) - 1):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
if activation_func == "relu": if activation_func == "relu":
linears.append(torch.nn.ReLU()) linears.append(torch.nn.ReLU())
if activation_func == "leakyrelu": elif activation_func == "leakyrelu":
linears.append(torch.nn.LeakyReLU()) linears.append(torch.nn.LeakyReLU())
elif activation_func == 'linear' or activation_func is None:
pass
else:
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
if add_layer_norm: if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
@ -46,7 +52,7 @@ class HypernetworkModule(torch.nn.Module):
self.load_state_dict(state_dict) self.load_state_dict(state_dict)
else: else:
for layer in self.linear: for layer in self.linear:
if not "ReLU" in layer.__str__(): if type(layer) == torch.nn.Linear:
layer.weight.data.normal_(mean=0.0, std=0.01) layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_() layer.bias.data.zero_()
@ -74,7 +80,7 @@ class HypernetworkModule(torch.nn.Module):
def trainables(self): def trainables(self):
layer_structure = [] layer_structure = []
for layer in self.linear: for layer in self.linear:
if not "ReLU" in layer.__str__(): if type(layer) == torch.nn.Linear:
layer_structure += [layer.weight, layer.bias] layer_structure += [layer.weight, layer.bias]
return layer_structure return layer_structure