Update textual_inversion.py
This commit is contained in:
parent
050a6a798c
commit
5841990b0d
1 changed files with 22 additions and 3 deletions
|
@ -7,6 +7,9 @@ import tqdm
|
|||
import html
|
||||
import datetime
|
||||
|
||||
from PIL import Image, PngImagePlugin
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models
|
||||
import modules.textual_inversion.dataset
|
||||
|
@ -80,7 +83,15 @@ class EmbeddingDatabase:
|
|||
def process_file(path, filename):
|
||||
name = os.path.splitext(filename)[0]
|
||||
|
||||
data = torch.load(path, map_location="cpu")
|
||||
data = []
|
||||
|
||||
if filename.upper().endswith('.PNG'):
|
||||
embed_image = Image.open(path)
|
||||
if 'sd-embedding' in embed_image.text:
|
||||
embeddingData = base64.b64decode(embed_image.text['sd-embedding'])
|
||||
data = torch.load(BytesIO(embeddingData), map_location="cpu")
|
||||
else:
|
||||
data = torch.load(path, map_location="cpu")
|
||||
|
||||
# textual inversion embeddings
|
||||
if 'string_to_param' in data:
|
||||
|
@ -156,7 +167,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
|
|||
return fn
|
||||
|
||||
|
||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
|
||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding):
|
||||
assert embedding_name, 'embedding not selected'
|
||||
|
||||
shared.state.textinfo = "Initializing textual inversion training..."
|
||||
|
@ -244,7 +255,15 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
|
|||
image = processed.images[0]
|
||||
|
||||
shared.state.current_image = image
|
||||
image.save(last_saved_image)
|
||||
|
||||
if save_image_with_stored_embedding:
|
||||
info = PngImagePlugin.PngInfo()
|
||||
info.add_text("sd-embedding", base64.b64encode(open(last_saved_file,'rb').read()))
|
||||
image.save(last_saved_image, "PNG", pnginfo=info)
|
||||
else:
|
||||
image.save(last_saved_image)
|
||||
|
||||
|
||||
|
||||
last_saved_image += f", prompt: {text}"
|
||||
|
||||
|
|
Loading…
Reference in a new issue