Add PNG alpha channel as weight maps to data entries
This commit is contained in:
parent
c4bfd20f31
commit
21642000b3
1 changed files with 38 additions and 13 deletions
|
@ -19,9 +19,10 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*")
|
||||||
|
|
||||||
|
|
||||||
class DatasetEntry:
|
class DatasetEntry:
|
||||||
def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
|
def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
self.filename_text = filename_text
|
self.filename_text = filename_text
|
||||||
|
self.weight = weight
|
||||||
self.latent_dist = latent_dist
|
self.latent_dist = latent_dist
|
||||||
self.latent_sample = latent_sample
|
self.latent_sample = latent_sample
|
||||||
self.cond = cond
|
self.cond = cond
|
||||||
|
@ -56,10 +57,16 @@ class PersonalizedBase(Dataset):
|
||||||
|
|
||||||
print("Preparing dataset...")
|
print("Preparing dataset...")
|
||||||
for path in tqdm.tqdm(self.image_paths):
|
for path in tqdm.tqdm(self.image_paths):
|
||||||
|
alpha_channel = None
|
||||||
if shared.state.interrupted:
|
if shared.state.interrupted:
|
||||||
raise Exception("interrupted")
|
raise Exception("interrupted")
|
||||||
try:
|
try:
|
||||||
image = Image.open(path).convert('RGB')
|
image = Image.open(path)
|
||||||
|
#Currently does not work for single color transparency
|
||||||
|
#We would need to read image.info['transparency'] for that
|
||||||
|
if 'A' in image.getbands():
|
||||||
|
alpha_channel = image.getchannel('A')
|
||||||
|
image = image.convert('RGB')
|
||||||
if not varsize:
|
if not varsize:
|
||||||
image = image.resize((width, height), PIL.Image.BICUBIC)
|
image = image.resize((width, height), PIL.Image.BICUBIC)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -87,17 +94,33 @@ class PersonalizedBase(Dataset):
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
|
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
|
||||||
|
|
||||||
if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
|
#Perform latent sampling, even for random sampling.
|
||||||
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
#We need the sample dimensions for the weights
|
||||||
latent_sampling_method = "once"
|
if latent_sampling_method == "deterministic":
|
||||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
|
if isinstance(latent_dist, DiagonalGaussianDistribution):
|
||||||
elif latent_sampling_method == "deterministic":
|
|
||||||
# Works only for DiagonalGaussianDistribution
|
# Works only for DiagonalGaussianDistribution
|
||||||
latent_dist.std = 0
|
latent_dist.std = 0
|
||||||
|
else:
|
||||||
|
latent_sampling_method = "once"
|
||||||
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
|
|
||||||
elif latent_sampling_method == "random":
|
if alpha_channel is not None:
|
||||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
|
channels, *latent_size = latent_sample.shape
|
||||||
|
weight_img = alpha_channel.resize(latent_size)
|
||||||
|
npweight = np.array(weight_img).astype(np.float32)
|
||||||
|
#Repeat for every channel in the latent sample
|
||||||
|
weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
|
||||||
|
#Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default.
|
||||||
|
weight -= weight.min()
|
||||||
|
weight /= weight.mean()
|
||||||
|
else:
|
||||||
|
#If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later
|
||||||
|
weight = torch.ones([channels] + latent_size)
|
||||||
|
|
||||||
|
if latent_sampling_method == "random":
|
||||||
|
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
|
||||||
|
else:
|
||||||
|
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)
|
||||||
|
|
||||||
if not (self.tag_drop_out != 0 or self.shuffle_tags):
|
if not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||||
entry.cond_text = self.create_text(filename_text)
|
entry.cond_text = self.create_text(filename_text)
|
||||||
|
@ -110,6 +133,7 @@ class PersonalizedBase(Dataset):
|
||||||
del torchdata
|
del torchdata
|
||||||
del latent_dist
|
del latent_dist
|
||||||
del latent_sample
|
del latent_sample
|
||||||
|
del weight
|
||||||
|
|
||||||
self.length = len(self.dataset)
|
self.length = len(self.dataset)
|
||||||
self.groups = list(groups.values())
|
self.groups = list(groups.values())
|
||||||
|
@ -195,6 +219,7 @@ class BatchLoader:
|
||||||
self.cond_text = [entry.cond_text for entry in data]
|
self.cond_text = [entry.cond_text for entry in data]
|
||||||
self.cond = [entry.cond for entry in data]
|
self.cond = [entry.cond for entry in data]
|
||||||
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
|
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
|
||||||
|
self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
|
||||||
#self.emb_index = [entry.emb_index for entry in data]
|
#self.emb_index = [entry.emb_index for entry in data]
|
||||||
#print(self.latent_sample.device)
|
#print(self.latent_sample.device)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue