add an option to avoid dying relu

This commit is contained in:
discus0434 2022-10-22 12:02:41 +00:00
parent dcb45dfecf
commit fccba4729d

View file

@ -32,7 +32,6 @@ class HypernetworkModule(torch.nn.Module):
assert layer_structure is not None, "layer_structure must not be None" assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
assert activation_func not in self.activation_dict.keys() + "linear", f"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
linears = [] linears = []
for i in range(len(layer_structure) - 1): for i in range(len(layer_structure) - 1):
@ -43,12 +42,13 @@ class HypernetworkModule(torch.nn.Module):
# Add an activation func # Add an activation func
if activation_func == "linear" or activation_func is None: if activation_func == "linear" or activation_func is None:
pass pass
# If ReLU, Skip adding it to the first layer to avoid dying ReLU
elif activation_func == "relu" and i < 1:
pass
elif activation_func in self.activation_dict: elif activation_func in self.activation_dict:
linears.append(self.activation_dict[activation_func]()) linears.append(self.activation_dict[activation_func]())
else: else:
raise RuntimeError( raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
)
# Add dropout # Add dropout
if use_dropout: if use_dropout:
@ -166,8 +166,8 @@ class Hypernetwork:
for size, sd in state_dict.items(): for size, sd in state_dict.items():
if type(size) == int: if type(size) == int:
self.layers[size] = ( self.layers[size] = (
HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func), HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func), HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
) )
self.name = state_dict.get('name', self.name) self.name = state_dict.get('name', self.name)