Add PNG alpha channel as weight maps to data entries

This commit is contained in:
Shondoit 2023-01-12 15:29:19 +01:00
parent c4bfd20f31
commit 21642000b3

View file

@ -19,9 +19,10 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class DatasetEntry: class DatasetEntry:
def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None): def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
self.filename = filename self.filename = filename
self.filename_text = filename_text self.filename_text = filename_text
self.weight = weight
self.latent_dist = latent_dist self.latent_dist = latent_dist
self.latent_sample = latent_sample self.latent_sample = latent_sample
self.cond = cond self.cond = cond
@ -56,10 +57,16 @@ class PersonalizedBase(Dataset):
print("Preparing dataset...") print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths): for path in tqdm.tqdm(self.image_paths):
alpha_channel = None
if shared.state.interrupted: if shared.state.interrupted:
raise Exception("interrupted") raise Exception("interrupted")
try: try:
image = Image.open(path).convert('RGB') image = Image.open(path)
#Currently does not work for single color transparency
#We would need to read image.info['transparency'] for that
if 'A' in image.getbands():
alpha_channel = image.getchannel('A')
image = image.convert('RGB')
if not varsize: if not varsize:
image = image.resize((width, height), PIL.Image.BICUBIC) image = image.resize((width, height), PIL.Image.BICUBIC)
except Exception: except Exception:
@ -87,17 +94,33 @@ class PersonalizedBase(Dataset):
with devices.autocast(): with devices.autocast():
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0)) latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)): #Perform latent sampling, even for random sampling.
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) #We need the sample dimensions for the weights
latent_sampling_method = "once" if latent_sampling_method == "deterministic":
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) if isinstance(latent_dist, DiagonalGaussianDistribution):
elif latent_sampling_method == "deterministic":
# Works only for DiagonalGaussianDistribution # Works only for DiagonalGaussianDistribution
latent_dist.std = 0 latent_dist.std = 0
else:
latent_sampling_method = "once"
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
elif latent_sampling_method == "random": if alpha_channel is not None:
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist) channels, *latent_size = latent_sample.shape
weight_img = alpha_channel.resize(latent_size)
npweight = np.array(weight_img).astype(np.float32)
#Repeat for every channel in the latent sample
weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
#Normalize the weight to a minimum of 0 and a mean of 1 so that the loss remains comparable to the default (unweighted) loss.
weight -= weight.min()
weight /= weight.mean()
else:
#If an image does not have an alpha channel, add an all-ones weight map anyway so we can stack it later
weight = torch.ones([channels] + latent_size)
if latent_sampling_method == "random":
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
else:
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)
if not (self.tag_drop_out != 0 or self.shuffle_tags): if not (self.tag_drop_out != 0 or self.shuffle_tags):
entry.cond_text = self.create_text(filename_text) entry.cond_text = self.create_text(filename_text)
@ -110,6 +133,7 @@ class PersonalizedBase(Dataset):
del torchdata del torchdata
del latent_dist del latent_dist
del latent_sample del latent_sample
del weight
self.length = len(self.dataset) self.length = len(self.dataset)
self.groups = list(groups.values()) self.groups = list(groups.values())
@ -195,6 +219,7 @@ class BatchLoader:
self.cond_text = [entry.cond_text for entry in data] self.cond_text = [entry.cond_text for entry in data]
self.cond = [entry.cond for entry in data] self.cond = [entry.cond for entry in data]
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
#self.emb_index = [entry.emb_index for entry in data] #self.emb_index = [entry.emb_index for entry in data]
#print(self.latent_sample.device) #print(self.latent_sample.device)