small fix
This commit is contained in:
parent
97749b7c7d
commit
6a4fa73a38
1 changed files with 3 additions and 4 deletions
|
@ -51,10 +51,9 @@ class HypernetworkModule(torch.nn.Module):
|
|||
if add_layer_norm:
|
||||
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
||||
|
||||
# Add dropout
|
||||
if use_dropout:
|
||||
p = 0.5 if 0 <= i <= len(layer_structure) - 3 else 0.2
|
||||
linears.append(torch.nn.Dropout(p=p))
|
||||
# Add dropout expect last layer
|
||||
if use_dropout and i < len(layer_structure) - 3:
|
||||
linears.append(torch.nn.Dropout(p=0.3))
|
||||
|
||||
self.linear = torch.nn.Sequential(*linears)
|
||||
|
||||
|
|
Loading…
Reference in a new issue