Update textual_inversion.py
This commit is contained in:
parent
050a6a798c
commit
5841990b0d
1 changed files with 22 additions and 3 deletions
|
@ -7,6 +7,9 @@ import tqdm
|
||||||
import html
|
import html
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
|
from PIL import Image, PngImagePlugin
|
||||||
|
import base64
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
from modules import shared, devices, sd_hijack, processing, sd_models
|
from modules import shared, devices, sd_hijack, processing, sd_models
|
||||||
import modules.textual_inversion.dataset
|
import modules.textual_inversion.dataset
|
||||||
|
@ -80,6 +83,14 @@ class EmbeddingDatabase:
|
||||||
def process_file(path, filename):
|
def process_file(path, filename):
|
||||||
name = os.path.splitext(filename)[0]
|
name = os.path.splitext(filename)[0]
|
||||||
|
|
||||||
|
data = []
|
||||||
|
|
||||||
|
if filename.upper().endswith('.PNG'):
|
||||||
|
embed_image = Image.open(path)
|
||||||
|
if 'sd-embedding' in embed_image.text:
|
||||||
|
embeddingData = base64.b64decode(embed_image.text['sd-embedding'])
|
||||||
|
data = torch.load(BytesIO(embeddingData), map_location="cpu")
|
||||||
|
else:
|
||||||
data = torch.load(path, map_location="cpu")
|
data = torch.load(path, map_location="cpu")
|
||||||
|
|
||||||
# textual inversion embeddings
|
# textual inversion embeddings
|
||||||
|
@ -156,7 +167,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
|
||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
|
||||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
|
def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding):
|
||||||
assert embedding_name, 'embedding not selected'
|
assert embedding_name, 'embedding not selected'
|
||||||
|
|
||||||
shared.state.textinfo = "Initializing textual inversion training..."
|
shared.state.textinfo = "Initializing textual inversion training..."
|
||||||
|
@ -244,8 +255,16 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
|
||||||
image = processed.images[0]
|
image = processed.images[0]
|
||||||
|
|
||||||
shared.state.current_image = image
|
shared.state.current_image = image
|
||||||
|
|
||||||
|
if save_image_with_stored_embedding:
|
||||||
|
info = PngImagePlugin.PngInfo()
|
||||||
|
info.add_text("sd-embedding", base64.b64encode(open(last_saved_file,'rb').read()))
|
||||||
|
image.save(last_saved_image, "PNG", pnginfo=info)
|
||||||
|
else:
|
||||||
image.save(last_saved_image)
|
image.save(last_saved_image)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
last_saved_image += f", prompt: {text}"
|
last_saved_image += f", prompt: {text}"
|
||||||
|
|
||||||
shared.state.job_no = embedding.step
|
shared.state.job_no = embedding.step
|
||||||
|
|
Loading…
Reference in a new issue