Update to support embedding with length greater than 1.

parent 4cafad66d2
commit 0dca0db7eb

2 changed files with 10 additions and 7 deletions
@@ -7,7 +7,7 @@ set VENV_DIR=venv
 
 mkdir tmp 2>NUL
 
-set TORCH_COMMAND=pip install torch --extra-index-url https://download.pytorch.org/whl/cu113
+set TORCH_COMMAND=pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
 set REQS_FILE=requirements_versions.txt
 
 %PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
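The version pin matters because the webui.py changes below include a workaround specific to torch 1.12.1 when loading embedding files saved under torch 1.11. A minimal sketch, not part of the commit, for confirming which build the venv actually resolved:

import torch

# Quick sanity check (hypothetical, not in this commit): confirm that the
# build installed by TORCH_COMMAND is the one Python actually imports.
print("torch version:", torch.__version__)       # expected to start with "1.12.1"
print("CUDA available:", torch.cuda.is_available())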
webui.py (15 changes)
@@ -746,9 +746,9 @@ class StableDiffusionModelHijack:
             if hasattr(param_dict, '_parameters'):
                 param_dict = getattr(param_dict, '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
             assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-            emb = next(iter(param_dict.items()))[1].reshape(768)
-            self.word_embeddings[name] = emb
-            self.word_embeddings_checksums[name] = f'{const_hash(emb)&0xffff:04x}'
+            emb = next(iter(param_dict.items()))[1]
+            self.word_embeddings[name] = emb.detach()
+            self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1))&0xffff:04x}'
 
             ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
 
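For context: a textual-inversion embedding file keeps its learned vectors under the 'string_to_param' key, and dropping the .reshape(768) means the stored tensor may now be [n, 768] (n vectors) rather than a single 768-wide vector. A rough sketch of what the loader now accepts; the file layout and the 768 width are assumptions taken from the surrounding code, not guaranteed by the commit:

import torch

# Assumed layout of a multi-vector embedding file: one named term mapping to
# an [n, 768] tensor. With .reshape(768) removed, n > 1 now loads cleanly.
data = {'string_to_param': {'*': torch.randn(4, 768)}}

param_dict = data['string_to_param']
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
print(emb.shape)        # torch.Size([4, 768]) -- kept as-is, no reshape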
@@ -838,9 +838,10 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                 found = False
                 for ids, word in possible_matches:
                     if tokens[i:i+len(ids)] == ids:
+                        emb_len = int(self.hijack.word_embeddings[word].shape[0])
                         fixes.append((len(remade_tokens), word))
-                        remade_tokens.append(777)
-                        multipliers.append(mult)
+                        remade_tokens += [0] * emb_len
+                        multipliers += [mult] * emb_len
                         i += len(ids) - 1
                         found = True
                         used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
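Here emb_len is the number of vectors in the embedding (its first dimension), so tokenization now reserves that many placeholder slots and multiplier entries instead of a single one. A standalone illustration, with made-up values, of how the bookkeeping lists grow:

# Standalone illustration (made-up values, simplified from webui.py):
# an embedding with 3 vectors claims 3 placeholder token slots and
# 3 copies of the current attention multiplier.
emb_len = 3      # would come from word_embeddings[word].shape[0]
mult = 1.1

remade_tokens, multipliers, fixes = [], [], []

fixes.append((len(remade_tokens), "my-embedding"))  # where the embedding starts
remade_tokens += [0] * emb_len                      # placeholder ids
multipliers += [mult] * emb_len

print(remade_tokens)   # [0, 0, 0]
print(multipliers)     # [1.1, 1.1, 1.1]
print(fixes)           # [(0, 'my-embedding')]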
@@ -903,7 +904,9 @@ class EmbeddingsWithFixes(nn.Module):
         if batch_fixes is not None:
             for fixes, tensor in zip(batch_fixes, inputs_embeds):
                 for offset, word in fixes:
-                    tensor[offset] = self.embeddings.word_embeddings[word]
+                    emb = self.embeddings.word_embeddings[word]
+                    emb_len = min(tensor.shape[0]-offset, emb.shape[0])
+                    tensor[offset:offset+emb_len] = self.embeddings.word_embeddings[word][0:emb_len]
 
         return inputs_embeds
 
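At embedding time, all emb_len rows are written into the prompt's embedding tensor at the recorded offset, and the min() clamps the copy so an embedding that starts near the end of the token window cannot write past it. A rough standalone sketch of that slice assignment; the 77x768 shape is illustrative, not taken from the diff:

import torch

# Illustrative shapes (assumed, not from the diff): a 77-token prompt
# embedding of width 768, and a 4-vector custom embedding placed near the end.
tensor = torch.zeros(77, 768)     # one prompt's token embeddings
emb = torch.randn(4, 768)         # multi-vector embedding for one word
offset = 75                       # position recorded in `fixes`

emb_len = min(tensor.shape[0] - offset, emb.shape[0])   # 2: clamp to what fits
tensor[offset:offset + emb_len] = emb[0:emb_len]
print(emb_len)                    # 2 -- only the rows that fit are copied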