Back compatibility
This commit is contained in:
parent
c702d4d0df
commit
877d94f97c
1 changed files with 10 additions and 7 deletions
|
@ -28,7 +28,7 @@ class HypernetworkModule(torch.nn.Module):
|
||||||
"swish": torch.nn.Hardswish,
|
"swish": torch.nn.Hardswish,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
|
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False, activate_output=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
assert layer_structure is not None, "layer_structure must not be None"
|
assert layer_structure is not None, "layer_structure must not be None"
|
||||||
|
@ -42,7 +42,7 @@ class HypernetworkModule(torch.nn.Module):
|
||||||
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
||||||
|
|
||||||
# Add an activation func except last layer
|
# Add an activation func except last layer
|
||||||
if activation_func == "linear" or activation_func is None or i >= len(layer_structure) - 2:
|
if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
|
||||||
pass
|
pass
|
||||||
elif activation_func in self.activation_dict:
|
elif activation_func in self.activation_dict:
|
||||||
linears.append(self.activation_dict[activation_func]())
|
linears.append(self.activation_dict[activation_func]())
|
||||||
|
@ -105,7 +105,7 @@ class Hypernetwork:
|
||||||
filename = None
|
filename = None
|
||||||
name = None
|
name = None
|
||||||
|
|
||||||
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
|
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False, activate_output=False):
|
||||||
self.filename = None
|
self.filename = None
|
||||||
self.name = name
|
self.name = name
|
||||||
self.layers = {}
|
self.layers = {}
|
||||||
|
@ -116,11 +116,12 @@ class Hypernetwork:
|
||||||
self.activation_func = activation_func
|
self.activation_func = activation_func
|
||||||
self.add_layer_norm = add_layer_norm
|
self.add_layer_norm = add_layer_norm
|
||||||
self.use_dropout = use_dropout
|
self.use_dropout = use_dropout
|
||||||
|
self.activate_output = activate_output
|
||||||
|
|
||||||
for size in enable_sizes or []:
|
for size in enable_sizes or []:
|
||||||
self.layers[size] = (
|
self.layers[size] = (
|
||||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
|
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
|
||||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
|
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
|
||||||
)
|
)
|
||||||
|
|
||||||
def weights(self):
|
def weights(self):
|
||||||
|
@ -147,6 +148,7 @@ class Hypernetwork:
|
||||||
state_dict['use_dropout'] = self.use_dropout
|
state_dict['use_dropout'] = self.use_dropout
|
||||||
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
||||||
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
||||||
|
state_dict['activate_output'] = self.activate_output
|
||||||
|
|
||||||
torch.save(state_dict, filename)
|
torch.save(state_dict, filename)
|
||||||
|
|
||||||
|
@ -161,12 +163,13 @@ class Hypernetwork:
|
||||||
self.activation_func = state_dict.get('activation_func', None)
|
self.activation_func = state_dict.get('activation_func', None)
|
||||||
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||||
self.use_dropout = state_dict.get('use_dropout', False)
|
self.use_dropout = state_dict.get('use_dropout', False)
|
||||||
|
self.activate_output = state_dict.get('activate_output', True)
|
||||||
|
|
||||||
for size, sd in state_dict.items():
|
for size, sd in state_dict.items():
|
||||||
if type(size) == int:
|
if type(size) == int:
|
||||||
self.layers[size] = (
|
self.layers[size] = (
|
||||||
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
|
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
|
||||||
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
|
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.name = state_dict.get('name', self.name)
|
self.name = state_dict.get('name', self.name)
|
||||||
|
|
Loading…
Reference in a new issue