formatting

This commit is contained in:
DepFA 2022-10-12 13:15:35 +01:00 committed by GitHub
parent 50be33e953
commit 10a2de644f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -7,14 +7,14 @@ import tqdm
import html import html
import datetime import datetime
from PIL import Image,PngImagePlugin from PIL import Image, PngImagePlugin
from modules import shared, devices, sd_hijack, processing, sd_models from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
from modules.textual_inversion.image_embedding import (embedding_to_b64,embedding_from_b64, from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
insert_image_data_embed,extract_image_data_embed, insert_image_data_embed, extract_image_data_embed,
caption_image_overlay ) caption_image_overlay)
class Embedding: class Embedding:
def __init__(self, vec, name, step=None): def __init__(self, vec, name, step=None):
@ -90,10 +90,10 @@ class EmbeddingDatabase:
embed_image = Image.open(path) embed_image = Image.open(path)
if 'sd-ti-embedding' in embed_image.text: if 'sd-ti-embedding' in embed_image.text:
data = embedding_from_b64(embed_image.text['sd-ti-embedding']) data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
name = data.get('name',name) name = data.get('name', name)
else: else:
data = extract_image_data_embed(embed_image) data = extract_image_data_embed(embed_image)
name = data.get('name',name) name = data.get('name', name)
else: else:
data = torch.load(path, map_location="cpu") data = torch.load(path, map_location="cpu")
@ -285,14 +285,14 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
data = torch.load(last_saved_file) data = torch.load(last_saved_file)
info.add_text("sd-ti-embedding", embedding_to_b64(data)) info.add_text("sd-ti-embedding", embedding_to_b64(data))
title = "<{}>".format(data.get('name','???')) title = "<{}>".format(data.get('name', '???'))
checkpoint = sd_models.select_checkpoint() checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name footer_left = checkpoint.model_name
footer_mid = '[{}]'.format(checkpoint.hash) footer_mid = '[{}]'.format(checkpoint.hash)
footer_right = '{}'.format(embedding.step) footer_right = '{}'.format(embedding.step)
captioned_image = caption_image_overlay(image,title,footer_left,footer_mid,footer_right) captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
captioned_image = insert_image_data_embed(captioned_image,data) captioned_image = insert_image_data_embed(captioned_image, data)
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)