Allow variable img size

This commit is contained in:
dan 2023-01-07 21:07:27 +08:00
parent 151233399c
commit 448b9cedab
2 changed files with 13 additions and 9 deletions

View file

@ -17,7 +17,7 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class DatasetEntry:
def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, img_shape=None):
self.filename = filename
self.filename_text = filename_text
self.latent_dist = latent_dist
@ -25,6 +25,7 @@ class DatasetEntry:
self.cond = cond
self.cond_text = cond_text
self.pixel_values = pixel_values
self.img_shape = img_shape
class PersonalizedBase(Dataset):
@ -33,8 +34,6 @@ class PersonalizedBase(Dataset):
self.placeholder_token = placeholder_token
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.dataset = []
@ -59,7 +58,11 @@ class PersonalizedBase(Dataset):
if shared.state.interrupted:
raise Exception("interrupted")
try:
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
image = Image.open(path).convert('RGB')
if width < 2000:
image = image.resize((width, height), PIL.Image.BICUBIC)
else:
assert batch_size == 1, 'variable img size must have batch size 1'
except Exception:
continue
@ -88,14 +91,14 @@ class PersonalizedBase(Dataset):
if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
latent_sampling_method = "once"
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size)
elif latent_sampling_method == "deterministic":
# Works only for DiagonalGaussianDistribution
latent_dist.std = 0
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size)
elif latent_sampling_method == "random":
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, img_shape=image.size)
if not (self.tag_drop_out != 0 or self.shuffle_tags):
entry.cond_text = self.create_text(filename_text)
@ -151,6 +154,7 @@ class BatchLoader:
self.cond_text = [entry.cond_text for entry in data]
self.cond = [entry.cond for entry in data]
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
self.img_shape = [entry.img_shape for entry in data]
#self.emb_index = [entry.emb_index for entry in data]
#print(self.latent_sample.device)

View file

@ -451,8 +451,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
else:
p.prompt = batch.cond_text[0]
p.steps = 20
p.width = training_width
p.height = training_height
p.width = batch.img_shape[0][0]
p.height = batch.img_shape[0][1]
preview_text = p.prompt