add encoder and decoder classes

This commit is contained in:
DepFA 2022-10-09 22:05:09 +01:00 committed by GitHub
parent 969bd8256e
commit 5d12ec82d3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -16,6 +16,27 @@ import json
from modules import shared, devices, sd_hijack, processing, sd_models from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
class EmbeddingEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, torch.Tensor):
return {'EMBEDDINGTENSOR':obj.cpu().detach().numpy().tolist()}
return json.JSONEncoder.default(self, o)
class EmbeddingDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
def object_hook(self, d):
if 'EMBEDDINGTENSOR' in d:
return torch.from_numpy(np.array(d['EMBEDDINGTENSOR']))
return d
def embeddingToB64(data):
d = json.dumps(data,cls=EmbeddingEncoder)
return base64.b64encode(d.encode())
def EmbeddingFromB64(data):
d = base64.b64decode(data)
return json.loads(d,cls=EmbeddingDecoder)
class Embedding: class Embedding:
def __init__(self, vec, name, step=None): def __init__(self, vec, name, step=None):