formatting
This commit is contained in:
parent
50be33e953
commit
10a2de644f
1 changed files with 11 additions and 11 deletions
|
@ -7,14 +7,14 @@ import tqdm
|
||||||
import html
|
import html
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from PIL import Image,PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
|
|
||||||
from modules import shared, devices, sd_hijack, processing, sd_models
|
from modules import shared, devices, sd_hijack, processing, sd_models
|
||||||
import modules.textual_inversion.dataset
|
import modules.textual_inversion.dataset
|
||||||
|
|
||||||
from modules.textual_inversion.image_embedding import (embedding_to_b64,embedding_from_b64,
|
from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
|
||||||
insert_image_data_embed,extract_image_data_embed,
|
insert_image_data_embed, extract_image_data_embed,
|
||||||
caption_image_overlay )
|
caption_image_overlay)
|
||||||
|
|
||||||
class Embedding:
|
class Embedding:
|
||||||
def __init__(self, vec, name, step=None):
|
def __init__(self, vec, name, step=None):
|
||||||
|
@ -90,10 +90,10 @@ class EmbeddingDatabase:
|
||||||
embed_image = Image.open(path)
|
embed_image = Image.open(path)
|
||||||
if 'sd-ti-embedding' in embed_image.text:
|
if 'sd-ti-embedding' in embed_image.text:
|
||||||
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
||||||
name = data.get('name',name)
|
name = data.get('name', name)
|
||||||
else:
|
else:
|
||||||
data = extract_image_data_embed(embed_image)
|
data = extract_image_data_embed(embed_image)
|
||||||
name = data.get('name',name)
|
name = data.get('name', name)
|
||||||
else:
|
else:
|
||||||
data = torch.load(path, map_location="cpu")
|
data = torch.load(path, map_location="cpu")
|
||||||
|
|
||||||
|
@ -278,24 +278,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||||
shared.state.current_image = image
|
shared.state.current_image = image
|
||||||
|
|
||||||
if save_image_with_stored_embedding and os.path.exists(last_saved_file):
|
if save_image_with_stored_embedding and os.path.exists(last_saved_file):
|
||||||
|
|
||||||
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
|
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
|
||||||
|
|
||||||
info = PngImagePlugin.PngInfo()
|
info = PngImagePlugin.PngInfo()
|
||||||
data = torch.load(last_saved_file)
|
data = torch.load(last_saved_file)
|
||||||
info.add_text("sd-ti-embedding", embedding_to_b64(data))
|
info.add_text("sd-ti-embedding", embedding_to_b64(data))
|
||||||
|
|
||||||
title = "<{}>".format(data.get('name','???'))
|
title = "<{}>".format(data.get('name', '???'))
|
||||||
checkpoint = sd_models.select_checkpoint()
|
checkpoint = sd_models.select_checkpoint()
|
||||||
footer_left = checkpoint.model_name
|
footer_left = checkpoint.model_name
|
||||||
footer_mid = '[{}]'.format(checkpoint.hash)
|
footer_mid = '[{}]'.format(checkpoint.hash)
|
||||||
footer_right = '{}'.format(embedding.step)
|
footer_right = '{}'.format(embedding.step)
|
||||||
|
|
||||||
captioned_image = caption_image_overlay(image,title,footer_left,footer_mid,footer_right)
|
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
|
||||||
captioned_image = insert_image_data_embed(captioned_image,data)
|
captioned_image = insert_image_data_embed(captioned_image, data)
|
||||||
|
|
||||||
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
|
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
|
||||||
|
|
||||||
image.save(last_saved_image)
|
image.save(last_saved_image)
|
||||||
|
|
||||||
last_saved_image += f", prompt: {preview_text}"
|
last_saved_image += f", prompt: {preview_text}"
|
||||||
|
|
Loading…
Reference in a new issue