diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index b0cfbe71..32c67ccc 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -403,7 +403,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, shared.reload_hypernetworks() -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images @@ -456,7 +456,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, pin_memory = shared.opts.pin_memory - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize) if shared.opts.save_training_settings_to_txt: saved_params = dict( diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 88d68c76..bcad6848 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -17,7 +17,7 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*") class DatasetEntry: - def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None): + def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, img_shape=None): self.filename = filename self.filename_text = filename_text self.latent_dist = latent_dist @@ -25,16 +25,15 @@ class DatasetEntry: self.cond = cond self.cond_text = cond_text self.pixel_values = pixel_values + self.img_shape = img_shape class PersonalizedBase(Dataset): - def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False): re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None self.placeholder_token = placeholder_token - self.width = width - self.height = height self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.dataset = [] @@ -47,6 +46,8 @@ class PersonalizedBase(Dataset): assert data_root, 'dataset directory not specified' assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" + if varsize: + assert batch_size == 1, 'variable img size must have batch size 1' self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] @@ -59,7 +60,9 @@ class PersonalizedBase(Dataset): if shared.state.interrupted: raise Exception("interrupted") try: - image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) + image = Image.open(path).convert('RGB') + if not varsize: + image = image.resize((width, height), PIL.Image.BICUBIC) except Exception: continue @@ -88,14 +91,14 @@ class PersonalizedBase(Dataset): if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)): latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) latent_sampling_method = "once" - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size) elif latent_sampling_method == "deterministic": # Works only for DiagonalGaussianDistribution latent_dist.std = 0 latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size) elif latent_sampling_method == "random": - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, img_shape=image.size) if not (self.tag_drop_out != 0 or self.shuffle_tags): entry.cond_text = self.create_text(filename_text) @@ -151,6 +154,7 @@ class BatchLoader: self.cond_text = [entry.cond_text for entry in data] self.cond = [entry.cond for entry in data] self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) + self.img_shape = [entry.img_shape for entry in data] #self.emb_index = [entry.emb_index for entry in data] #print(self.latent_sample.device) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 217fe9eb..ad76297e 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -296,8 +296,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat if save_model_every or create_image_every: assert log_directory, "Log directory is empty" - -def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") @@ -351,7 +350,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ pin_memory = shared.opts.pin_memory - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize) if shared.opts.save_training_settings_to_txt: save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()}) @@ -493,8 +492,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ else: p.prompt = batch.cond_text[0] p.steps = 20 - p.width = training_width - p.height = training_height + p.width = batch.img_shape[0][0] + p.height = batch.img_shape[0][1] preview_text = p.prompt diff --git a/modules/ui.py b/modules/ui.py index 719c26b3..9d6b141e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1348,6 +1348,7 @@ def create_ui(): template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file") training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") + varsize = gr.Checkbox(label="Ignore dimension settings and do not resize images", value=False, elem_id="train_varsize") steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") with FormRow(): @@ -1454,6 +1455,7 @@ def create_ui(): log_directory, training_width, training_height, + varsize, steps, clip_grad_mode, clip_grad_value, @@ -1485,6 +1487,7 @@ def create_ui(): log_directory, training_width, training_height, + varsize, steps, clip_grad_mode, clip_grad_value,