diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index df278dc2..a95c7835 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -98,12 +98,12 @@ class PersonalizedBase(Dataset): def create_text(self, filename_text): text = random.choice(self.lines) text = text.replace("[name]", self.placeholder_token) + tags = filename_text.split(',') + if shared.opt.tag_drop_out != 0: + tags = [t for t in tags if random.random() > shared.opt.tag_drop_out] if shared.opts.shuffle_tags: - tags = filename_text.split(',') random.shuffle(tags) - text = text.replace("[filewords]", ','.join(tags)) - else: - text = text.replace("[filewords]", filename_text) + text = text.replace("[filewords]", ','.join(tags)) return text def __len__(self):