more careful loading of model weights (eliminates some issues with checkpoints that have weird cond_stage_model layer names)
This commit is contained in:
parent
c1093b8051
commit
10aca1ca3e
1 changed files with 25 additions and 3 deletions
|
@ -122,11 +122,33 @@ def select_checkpoint():
|
|||
return checkpoint_info
|
||||
|
||||
|
||||
chckpoint_dict_replacements = {
|
||||
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
|
||||
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
|
||||
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
|
||||
}
|
||||
|
||||
|
||||
def transform_checkpoint_dict_key(k):
|
||||
for text, replacement in chckpoint_dict_replacements.items():
|
||||
if k.startswith(text):
|
||||
k = replacement + k[len(text):]
|
||||
|
||||
return k
|
||||
|
||||
|
||||
def get_state_dict_from_checkpoint(pl_sd):
|
||||
if "state_dict" in pl_sd:
|
||||
return pl_sd["state_dict"]
|
||||
pl_sd = pl_sd["state_dict"]
|
||||
|
||||
return pl_sd
|
||||
sd = {}
|
||||
for k, v in pl_sd.items():
|
||||
new_key = transform_checkpoint_dict_key(k)
|
||||
|
||||
if new_key is not None:
|
||||
sd[new_key] = v
|
||||
|
||||
return sd
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_info):
|
||||
|
@ -141,7 +163,7 @@ def load_model_weights(model, checkpoint_info):
|
|||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
|
||||
sd = get_state_dict_from_checkpoint(pl_sd)
|
||||
model.load_state_dict(sd, strict=False)
|
||||
missing, extra = model.load_state_dict(sd, strict=False)
|
||||
|
||||
if shared.cmd_opts.opt_channelslast:
|
||||
model.to(memory_format=torch.channels_last)
|
||||
|
|
Loading…
Reference in a new issue