Turns out LayerNorm also has a weight and a bias, and these need to be pre-multiplied and trained for hypernets
This commit is contained in:
parent
e4877722e3
commit
03a1e288c4
1 changed file with 2 additions and 2 deletions
|
@ -52,7 +52,7 @@ class HypernetworkModule(torch.nn.Module):
|
||||||
self.load_state_dict(state_dict)
|
self.load_state_dict(state_dict)
|
||||||
else:
|
else:
|
||||||
for layer in self.linear:
|
for layer in self.linear:
|
||||||
if type(layer) == torch.nn.Linear:
|
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
||||||
layer.weight.data.normal_(mean=0.0, std=0.01)
|
layer.weight.data.normal_(mean=0.0, std=0.01)
|
||||||
layer.bias.data.zero_()
|
layer.bias.data.zero_()
|
||||||
|
|
||||||
|
@ -80,7 +80,7 @@ class HypernetworkModule(torch.nn.Module):
|
||||||
def trainables(self):
|
def trainables(self):
|
||||||
layer_structure = []
|
layer_structure = []
|
||||||
for layer in self.linear:
|
for layer in self.linear:
|
||||||
if type(layer) == torch.nn.Linear:
|
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
||||||
layer_structure += [layer.weight, layer.bias]
|
layer_structure += [layer.weight, layer.bias]
|
||||||
return layer_structure
|
return layer_structure
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue