small fix
This commit is contained in:
parent
fccba4729d
commit
7912acef72
2 changed files with 5 additions and 8 deletions
|
@ -42,22 +42,20 @@ class HypernetworkModule(torch.nn.Module):
|
||||||
# Add an activation func
|
# Add an activation func
|
||||||
if activation_func == "linear" or activation_func is None:
|
if activation_func == "linear" or activation_func is None:
|
||||||
pass
|
pass
|
||||||
# If ReLU, Skip adding it to the first layer to avoid dying ReLU
|
|
||||||
elif activation_func == "relu" and i < 1:
|
|
||||||
pass
|
|
||||||
elif activation_func in self.activation_dict:
|
elif activation_func in self.activation_dict:
|
||||||
linears.append(self.activation_dict[activation_func]())
|
linears.append(self.activation_dict[activation_func]())
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
|
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
|
||||||
|
|
||||||
# Add dropout
|
|
||||||
if use_dropout:
|
|
||||||
linears.append(torch.nn.Dropout(p=0.3))
|
|
||||||
|
|
||||||
# Add layer normalization
|
# Add layer normalization
|
||||||
if add_layer_norm:
|
if add_layer_norm:
|
||||||
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
||||||
|
|
||||||
|
# Add dropout
|
||||||
|
if use_dropout:
|
||||||
|
p = 0.5 if 0 <= i <= len(layer_structure) - 3 else 0.2
|
||||||
|
linears.append(torch.nn.Dropout(p=p))
|
||||||
|
|
||||||
self.linear = torch.nn.Sequential(*linears)
|
self.linear = torch.nn.Sequential(*linears)
|
||||||
|
|
||||||
if state_dict is not None:
|
if state_dict is not None:
|
||||||
|
|
|
@ -1244,7 +1244,6 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
|
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
|
||||||
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
|
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
|
||||||
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
|
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
|
||||||
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
|
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=3):
|
with gr.Column(scale=3):
|
||||||
|
|
Loading…
Reference in a new issue