add an option to avoid dying relu
This commit is contained in:
parent
dcb45dfecf
commit
fccba4729d
1 changed files with 6 additions and 6 deletions
|
@ -32,7 +32,6 @@ class HypernetworkModule(torch.nn.Module):
|
|||
assert layer_structure is not None, "layer_structure must not be None"
|
||||
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
||||
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
||||
assert activation_func not in self.activation_dict.keys() + "linear", f"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
|
||||
|
||||
linears = []
|
||||
for i in range(len(layer_structure) - 1):
|
||||
|
@ -43,12 +42,13 @@ class HypernetworkModule(torch.nn.Module):
|
|||
# Add an activation func
|
||||
if activation_func == "linear" or activation_func is None:
|
||||
pass
|
||||
# If ReLU, Skip adding it to the first layer to avoid dying ReLU
|
||||
elif activation_func == "relu" and i < 1:
|
||||
pass
|
||||
elif activation_func in self.activation_dict:
|
||||
linears.append(self.activation_dict[activation_func]())
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
|
||||
)
|
||||
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
|
||||
|
||||
# Add dropout
|
||||
if use_dropout:
|
||||
|
@ -166,8 +166,8 @@ class Hypernetwork:
|
|||
for size, sd in state_dict.items():
|
||||
if type(size) == int:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func),
|
||||
HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func),
|
||||
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
|
||||
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
|
||||
)
|
||||
|
||||
self.name = state_dict.get('name', self.name)
|
||||
|
|
Loading…
Reference in a new issue