add pixel data footer
This commit is contained in:
parent
ce2d7f7eac
commit
707a431100
1 changed files with 46 additions and 2 deletions
|
@ -12,6 +12,7 @@ from ..images import captionImageOverlay
|
|||
import numpy as np
|
||||
import base64
|
||||
import json
|
||||
import zlib
|
||||
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models
|
||||
import modules.textual_inversion.dataset
|
||||
|
@ -20,7 +21,7 @@ class EmbeddingEncoder(json.JSONEncoder):
|
|||
def default(self, obj):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()}
|
||||
return json.JSONEncoder.default(self, o)
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
|
||||
class EmbeddingDecoder(json.JSONDecoder):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
@ -38,6 +39,45 @@ def embeddingFromB64(data):
|
|||
d = base64.b64decode(data)
|
||||
return json.loads(d,cls=EmbeddingDecoder)
|
||||
|
||||
def appendImageDataFooter(image,data):
|
||||
d = 3
|
||||
data_compressed = zlib.compress( json.dumps(data,cls=EmbeddingEncoder).encode(),level=9)
|
||||
dnp = np.frombuffer(data_compressed,np.uint8).copy()
|
||||
w = image.size[0]
|
||||
next_size = dnp.shape[0] + (w-(dnp.shape[0]%w))
|
||||
next_size = next_size + ((w*d)-(next_size%(w*d)))
|
||||
dnp.resize(next_size)
|
||||
dnp = dnp.reshape((-1,w,d))
|
||||
print(dnp.shape)
|
||||
im = Image.fromarray(dnp,mode='RGB')
|
||||
background = Image.new('RGB',(image.size[0],image.size[1]+im.size[1]+1),(0,0,0))
|
||||
background.paste(image,(0,0))
|
||||
background.paste(im,(0,image.size[1]+1))
|
||||
return background
|
||||
|
||||
def crop_black(img,tol=0):
|
||||
mask = (img>tol).all(2)
|
||||
mask0,mask1 = mask.any(0),mask.any(1)
|
||||
col_start,col_end = mask0.argmax(),mask.shape[1]-mask0[::-1].argmax()
|
||||
row_start,row_end = mask1.argmax(),mask.shape[0]-mask1[::-1].argmax()
|
||||
return img[row_start:row_end,col_start:col_end]
|
||||
|
||||
def extractImageDataFooter(image):
|
||||
d=3
|
||||
outarr = crop_black(np.array(image.getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) )
|
||||
lastRow = np.where( np.sum(outarr, axis=(1,2))==0)
|
||||
if lastRow[0].shape[0] == 0:
|
||||
print('Image data block not found.')
|
||||
return None
|
||||
lastRow = lastRow[0]
|
||||
|
||||
lastRow = lastRow.max()
|
||||
|
||||
dataBlock = outarr[lastRow+1::].astype(np.uint8).flatten().tobytes()
|
||||
print(lastRow)
|
||||
data = zlib.decompress(dataBlock)
|
||||
return json.loads(data,cls=EmbeddingDecoder)
|
||||
|
||||
class Embedding:
|
||||
def __init__(self, vec, name, step=None):
|
||||
self.vec = vec
|
||||
|
@ -113,6 +153,9 @@ class EmbeddingDatabase:
|
|||
if 'sd-ti-embedding' in embed_image.text:
|
||||
data = embeddingFromB64(embed_image.text['sd-ti-embedding'])
|
||||
name = data.get('name',name)
|
||||
else:
|
||||
data = extractImageDataFooter(embed_image)
|
||||
name = data.get('name',name)
|
||||
else:
|
||||
data = torch.load(path, map_location="cpu")
|
||||
|
||||
|
@ -190,7 +233,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
|
|||
return fn
|
||||
|
||||
|
||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file):
|
||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding):
|
||||
assert embedding_name, 'embedding not selected'
|
||||
|
||||
shared.state.textinfo = "Initializing textual inversion training..."
|
||||
|
@ -308,6 +351,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
|||
footer_right = '{}'.format(embedding.step)
|
||||
|
||||
captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right)
|
||||
captioned_image = appendImageDataFooter(captioned_image,data)
|
||||
|
||||
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
|
||||
|
||||
|
|
Loading…
Reference in a new issue