feat(api): return more data for embeddings
This commit is contained in:
parent
b5819d9bf1
commit
c65909ad16
3 changed files with 28 additions and 8 deletions
|
@ -330,9 +330,22 @@ class Api:
|
||||||
|
|
||||||
def get_embeddings(self):
|
def get_embeddings(self):
|
||||||
db = sd_hijack.model_hijack.embedding_db
|
db = sd_hijack.model_hijack.embedding_db
|
||||||
|
|
||||||
|
def convert_embedding(embedding):
|
||||||
|
return {
|
||||||
|
"step": embedding.step,
|
||||||
|
"sd_checkpoint": embedding.sd_checkpoint,
|
||||||
|
"sd_checkpoint_name": embedding.sd_checkpoint_name,
|
||||||
|
"shape": embedding.shape,
|
||||||
|
"vectors": embedding.vectors,
|
||||||
|
}
|
||||||
|
|
||||||
|
def convert_embeddings(embeddings):
|
||||||
|
return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"loaded": sorted(db.word_embeddings.keys()),
|
"loaded": convert_embeddings(db.word_embeddings),
|
||||||
"skipped": sorted(db.skipped_embeddings),
|
"skipped": convert_embeddings(db.skipped_embeddings),
|
||||||
}
|
}
|
||||||
|
|
||||||
def refresh_checkpoints(self):
|
def refresh_checkpoints(self):
|
||||||
|
|
|
@ -249,6 +249,13 @@ class ArtistItem(BaseModel):
|
||||||
score: float = Field(title="Score")
|
score: float = Field(title="Score")
|
||||||
category: str = Field(title="Category")
|
category: str = Field(title="Category")
|
||||||
|
|
||||||
|
class EmbeddingItem(BaseModel):
|
||||||
|
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
|
||||||
|
sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
|
||||||
|
sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
|
||||||
|
shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
|
||||||
|
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
|
||||||
|
|
||||||
class EmbeddingsResponse(BaseModel):
|
class EmbeddingsResponse(BaseModel):
|
||||||
loaded: List[str] = Field(title="Loaded", description="Embeddings loaded for the current model")
|
loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
|
||||||
skipped: List[str] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
|
skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
|
|
@ -59,7 +59,7 @@ class EmbeddingDatabase:
|
||||||
def __init__(self, embeddings_dir):
|
def __init__(self, embeddings_dir):
|
||||||
self.ids_lookup = {}
|
self.ids_lookup = {}
|
||||||
self.word_embeddings = {}
|
self.word_embeddings = {}
|
||||||
self.skipped_embeddings = []
|
self.skipped_embeddings = {}
|
||||||
self.dir_mtime = None
|
self.dir_mtime = None
|
||||||
self.embeddings_dir = embeddings_dir
|
self.embeddings_dir = embeddings_dir
|
||||||
self.expected_shape = -1
|
self.expected_shape = -1
|
||||||
|
@ -91,7 +91,7 @@ class EmbeddingDatabase:
|
||||||
self.dir_mtime = mt
|
self.dir_mtime = mt
|
||||||
self.ids_lookup.clear()
|
self.ids_lookup.clear()
|
||||||
self.word_embeddings.clear()
|
self.word_embeddings.clear()
|
||||||
self.skipped_embeddings = []
|
self.skipped_embeddings.clear()
|
||||||
self.expected_shape = self.get_expected_shape()
|
self.expected_shape = self.get_expected_shape()
|
||||||
|
|
||||||
def process_file(path, filename):
|
def process_file(path, filename):
|
||||||
|
@ -136,7 +136,7 @@ class EmbeddingDatabase:
|
||||||
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
||||||
self.register_embedding(embedding, shared.sd_model)
|
self.register_embedding(embedding, shared.sd_model)
|
||||||
else:
|
else:
|
||||||
self.skipped_embeddings.append(name)
|
self.skipped_embeddings[name] = embedding
|
||||||
|
|
||||||
for fn in os.listdir(self.embeddings_dir):
|
for fn in os.listdir(self.embeddings_dir):
|
||||||
try:
|
try:
|
||||||
|
@ -153,7 +153,7 @@ class EmbeddingDatabase:
|
||||||
|
|
||||||
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
|
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
|
||||||
if len(self.skipped_embeddings) > 0:
|
if len(self.skipped_embeddings) > 0:
|
||||||
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings)}")
|
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
|
||||||
|
|
||||||
def find_embedding_at_position(self, tokens, offset):
|
def find_embedding_at_position(self, tokens, offset):
|
||||||
token = tokens[offset]
|
token = tokens[offset]
|
||||||
|
|
Loading…
Reference in a new issue