It turns out that LayerNorm also has a weight and a bias, so it too needs to be pre-multiplied and trained for hypernets

This commit is contained in:
AUTOMATIC 2022-10-21 10:13:24 +03:00
parent e4877722e3
commit 03a1e288c4

View file

@ -52,7 +52,7 @@ class HypernetworkModule(torch.nn.Module):
self.load_state_dict(state_dict) self.load_state_dict(state_dict)
else: else:
for layer in self.linear: for layer in self.linear:
if type(layer) == torch.nn.Linear: if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
layer.weight.data.normal_(mean=0.0, std=0.01) layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_() layer.bias.data.zero_()
@ -80,7 +80,7 @@ class HypernetworkModule(torch.nn.Module):
def trainables(self): def trainables(self):
layer_structure = [] layer_structure = []
for layer in self.linear: for layer in self.linear:
if type(layer) == torch.nn.Linear: if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
layer_structure += [layer.weight, layer.bias] layer_structure += [layer.weight, layer.bias]
return layer_structure return layer_structure