Update textual_inversion.py

This commit is contained in:
DepFA 2022-10-09 05:38:38 +01:00 committed by GitHub
parent 050a6a798c
commit 5841990b0d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -7,6 +7,9 @@ import tqdm
import html import html
import datetime import datetime
from PIL import Image, PngImagePlugin
import base64
from io import BytesIO
from modules import shared, devices, sd_hijack, processing, sd_models from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
@ -80,6 +83,14 @@ class EmbeddingDatabase:
def process_file(path, filename): def process_file(path, filename):
name = os.path.splitext(filename)[0] name = os.path.splitext(filename)[0]
data = []
if filename.upper().endswith('.PNG'):
embed_image = Image.open(path)
if 'sd-embedding' in embed_image.text:
embeddingData = base64.b64decode(embed_image.text['sd-embedding'])
data = torch.load(BytesIO(embeddingData), map_location="cpu")
else:
data = torch.load(path, map_location="cpu") data = torch.load(path, map_location="cpu")
# textual inversion embeddings # textual inversion embeddings
@ -156,7 +167,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn return fn
def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file): def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding):
assert embedding_name, 'embedding not selected' assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..." shared.state.textinfo = "Initializing textual inversion training..."
@ -244,8 +255,16 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
image = processed.images[0] image = processed.images[0]
shared.state.current_image = image shared.state.current_image = image
if save_image_with_stored_embedding:
info = PngImagePlugin.PngInfo()
info.add_text("sd-embedding", base64.b64encode(open(last_saved_file,'rb').read()))
image.save(last_saved_image, "PNG", pnginfo=info)
else:
image.save(last_saved_image) image.save(last_saved_image)
last_saved_image += f", prompt: {text}" last_saved_image += f", prompt: {text}"
shared.state.job_no = embedding.step shared.state.job_no = embedding.step