Merge pull request #4117 from TinkTheBoush/master

Adding optional tag shuffling for training
This commit is contained in:
AUTOMATIC1111 2022-11-11 15:46:20 +03:00 committed by GitHub
commit 73776907ec
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 1 deletions

View file

@ -319,6 +319,8 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"shuffle_tags": OptionInfo(False, "Shuffleing tags by ',' when create texts."),
"tag_drop_out": OptionInfo(0, "Dropout tags when create texts", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.1}),
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),

View file

@ -98,7 +98,12 @@ class PersonalizedBase(Dataset):
def create_text(self, filename_text): def create_text(self, filename_text):
text = random.choice(self.lines) text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token) text = text.replace("[name]", self.placeholder_token)
text = text.replace("[filewords]", filename_text) tags = filename_text.split(',')
if shared.opts.tag_drop_out != 0:
tags = [t for t in tags if random.random() > shared.opts.tag_drop_out]
if shared.opts.shuffle_tags:
random.shuffle(tags)
text = text.replace("[filewords]", ','.join(tags))
return text return text
def __len__(self): def __len__(self):