add pixel data footer

This commit is contained in:
parent ce2d7f7eac
commit 707a431100

1 changed file with 46 additions and 2 deletions
@@ -12,6 +12,7 @@ from ..images import captionImageOverlay
 import numpy as np
 import base64
 import json
+import zlib
 
 from modules import shared, devices, sd_hijack, processing, sd_models
 import modules.textual_inversion.dataset
@@ -20,7 +21,7 @@ class EmbeddingEncoder(json.JSONEncoder):
     def default(self, obj):
         if isinstance(obj, torch.Tensor):
             return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()}
-        return json.JSONEncoder.default(self, o)
+        return json.JSONEncoder.default(self, obj)
 
 class EmbeddingDecoder(json.JSONDecoder):
     def __init__(self, *args, **kwargs):
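Usage note (illustrative, not part of this commit): EmbeddingEncoder serializes torch tensors as {'TORCHTENSOR': nested_list}, and embeddingFromB64 below relies on EmbeddingDecoder to reverse that mapping. A minimal round-trip sketch, assuming it runs in (or imports from) the module that defines both classes and that the decoder's object_hook rebuilds tensors from the 'TORCHTENSOR' key:

    import json
    import torch

    vec = torch.randn(2, 768)

    # tensor -> JSON text; json.dumps calls EmbeddingEncoder.default() for the tensor
    payload = json.dumps({'name': 'example', 'vec': vec}, cls=EmbeddingEncoder)

    # JSON text -> Python objects; 'vec' should come back in tensor-compatible form
    restored = json.loads(payload, cls=EmbeddingDecoder)
    print(torch.as_tensor(restored['vec']).shape)  # expected: torch.Size([2, 768])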
@@ -38,6 +39,45 @@ def embeddingFromB64(data):
     d = base64.b64decode(data)
     return json.loads(d,cls=EmbeddingDecoder)
 
+def appendImageDataFooter(image,data):
+    d = 3
+    data_compressed = zlib.compress( json.dumps(data,cls=EmbeddingEncoder).encode(),level=9)
+    dnp = np.frombuffer(data_compressed,np.uint8).copy()
+    w = image.size[0]
+    next_size = dnp.shape[0] + (w-(dnp.shape[0]%w))
+    next_size = next_size + ((w*d)-(next_size%(w*d)))
+    dnp.resize(next_size)
+    dnp = dnp.reshape((-1,w,d))
+    print(dnp.shape)
+    im = Image.fromarray(dnp,mode='RGB')
+    background = Image.new('RGB',(image.size[0],image.size[1]+im.size[1]+1),(0,0,0))
+    background.paste(image,(0,0))
+    background.paste(im,(0,image.size[1]+1))
+    return background
+
+def crop_black(img,tol=0):
+    mask = (img>tol).all(2)
+    mask0,mask1 = mask.any(0),mask.any(1)
+    col_start,col_end = mask0.argmax(),mask.shape[1]-mask0[::-1].argmax()
+    row_start,row_end = mask1.argmax(),mask.shape[0]-mask1[::-1].argmax()
+    return img[row_start:row_end,col_start:col_end]
+
+def extractImageDataFooter(image):
+    d = 3
+    outarr = crop_black(np.array(image.getdata()).reshape(image.size[1],image.size[0],d ).astype(np.uint8) )
+    lastRow = np.where( np.sum(outarr, axis=(1,2))==0)
+    if lastRow[0].shape[0] == 0:
+        print('Image data block not found.')
+        return None
+    lastRow = lastRow[0]
+
+    lastRow = lastRow.max()
+
+    dataBlock = outarr[lastRow+1::].astype(np.uint8).flatten().tobytes()
+    print(lastRow)
+    data = zlib.decompress(dataBlock)
+    return json.loads(data,cls=EmbeddingDecoder)
+
 class Embedding:
     def __init__(self, vec, name, step=None):
         self.vec = vec
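Usage note (illustrative, not part of this commit): appendImageDataFooter zlib-compresses the JSON-encoded data, pads the byte buffer so it fills whole w-pixel RGB rows (first up to a multiple of w, then up to a multiple of w*d), renders those rows as pixels, and pastes them below the image with a one-pixel black separator row. For example, with w=512 and a 1000-byte compressed payload, the buffer is padded to 1024 and then to 1536 bytes, i.e. exactly one 512-pixel RGB row. extractImageDataFooter reverses this: crop_black trims fully black border rows and columns, the last remaining fully black row marks the separator, and everything below it is decompressed and JSON-decoded. A minimal round-trip sketch, assuming PIL is installed and this runs in the module that defines the functions above:

    from PIL import Image

    base = Image.new('RGB', (512, 512), (32, 64, 96))   # stand-in for a generated image
    payload = {'name': 'my-embedding', 'step': 1000}    # hypothetical footer contents

    with_footer = appendImageDataFooter(base, payload)  # taller image: original + separator + data rows
    recovered = extractImageDataFooter(with_footer)     # should yield the payload dict again
    print(with_footer.size, recovered)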
@@ -113,6 +153,9 @@
             if 'sd-ti-embedding' in embed_image.text:
                 data = embeddingFromB64(embed_image.text['sd-ti-embedding'])
                 name = data.get('name',name)
+            else:
+                data = extractImageDataFooter(embed_image)
+                name = data.get('name',name)
         else:
             data = torch.load(path, map_location="cpu")
 
@@ -190,7 +233,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
     return fn
 
 
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding):
     assert embedding_name, 'embedding not selected'
 
     shared.state.textinfo = "Initializing textual inversion training..."
@@ -308,6 +351,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
             footer_right = '{}'.format(embedding.step)
 
             captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right)
+            captioned_image = appendImageDataFooter(captioned_image,data)
 
             captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
 