Optimized code for Ignoring last CLIP layers
This commit is contained in:
parent
6c383d2e82
commit
e59c66c008
1 changed files with 4 additions and 8 deletions
|
@ -282,14 +282,10 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens]
|
||||
tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device)
|
||||
|
||||
tmp = -opts.CLIP_ignore_last_layers
|
||||
if (opts.CLIP_ignore_last_layers == 0):
|
||||
outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids)
|
||||
z = outputs.last_hidden_state
|
||||
else:
|
||||
outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp)
|
||||
z = outputs.hidden_states[tmp]
|
||||
z = self.wrapped.transformer.text_model.final_layer_norm(z)
|
||||
tmp = -opts.CLIP_stop_at_last_layers
|
||||
outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp)
|
||||
z = outputs.hidden_states[tmp]
|
||||
z = self.wrapped.transformer.text_model.final_layer_norm(z)
|
||||
|
||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||
batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers]
|
||||
|
|
Loading…
Reference in a new issue