Merge branch 'AUTOMATIC1111:master' into master

This commit is contained in:
不会画画的中医不是好程序员 2022-10-13 12:35:39 +08:00 committed by GitHub
commit 0186db178e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 202 additions and 135 deletions

View file

@ -81,6 +81,9 @@ titles = {
"Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.", "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
"Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.", "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",
"Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
"Filename join string": "This string will be used to hoin split words into a single line if the option above is enabled.",
} }

View file

@ -2,33 +2,44 @@ import os.path
from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ProcessPoolExecutor
import multiprocessing import multiprocessing
import time import time
import re
re_special = re.compile(r'([\\()])')
def get_deepbooru_tags(pil_image): def get_deepbooru_tags(pil_image):
""" """
This method is for running only one image at a time for simple use. Used to the img2img interrogate. This method is for running only one image at a time for simple use. Used to the img2img interrogate.
""" """
from modules import shared # prevents circular reference from modules import shared # prevents circular reference
create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, shared.opts.deepbooru_sort_alpha)
shared.deepbooru_process_return["value"] = -1 try:
shared.deepbooru_process_queue.put(pil_image) create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts())
while shared.deepbooru_process_return["value"] == -1: return get_tags_from_process(pil_image)
time.sleep(0.2) finally:
tags = shared.deepbooru_process_return["value"]
release_process() release_process()
return tags
def deepbooru_process(queue, deepbooru_process_return, threshold, alpha_sort): def create_deepbooru_opts():
from modules import shared
return {
"use_spaces": shared.opts.deepbooru_use_spaces,
"use_escape": shared.opts.deepbooru_escape,
"alpha_sort": shared.opts.deepbooru_sort_alpha,
}
def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts):
model, tags = get_deepbooru_tags_model() model, tags = get_deepbooru_tags_model()
while True: # while process is running, keep monitoring queue for new image while True: # while process is running, keep monitoring queue for new image
pil_image = queue.get() pil_image = queue.get()
if pil_image == "QUIT": if pil_image == "QUIT":
break break
else: else:
deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort) deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts)
def create_deepbooru_process(threshold, alpha_sort): def create_deepbooru_process(threshold, deepbooru_opts):
""" """
Creates deepbooru process. A queue is created to send images into the process. This enables multiple images Creates deepbooru process. A queue is created to send images into the process. This enables multiple images
to be processed in a row without reloading the model or creating a new process. To return the data, a shared to be processed in a row without reloading the model or creating a new process. To return the data, a shared
@ -41,10 +52,23 @@ def create_deepbooru_process(threshold, alpha_sort):
shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue() shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
shared.deepbooru_process_return = shared.deepbooru_process_manager.dict() shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
shared.deepbooru_process_return["value"] = -1 shared.deepbooru_process_return["value"] = -1
shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, alpha_sort)) shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
shared.deepbooru_process.start() shared.deepbooru_process.start()
def get_tags_from_process(image):
from modules import shared
shared.deepbooru_process_return["value"] = -1
shared.deepbooru_process_queue.put(image)
while shared.deepbooru_process_return["value"] == -1:
time.sleep(0.2)
caption = shared.deepbooru_process_return["value"]
shared.deepbooru_process_return["value"] = -1
return caption
def release_process(): def release_process():
""" """
Stops the deepbooru process to return used memory Stops the deepbooru process to return used memory
@ -81,10 +105,15 @@ def get_deepbooru_tags_model():
return model, tags return model, tags
def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort): def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts):
import deepdanbooru as dd import deepdanbooru as dd
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
alpha_sort = deepbooru_opts['alpha_sort']
use_spaces = deepbooru_opts['use_spaces']
use_escape = deepbooru_opts['use_escape']
width = model.input_shape[2] width = model.input_shape[2]
height = model.input_shape[1] height = model.input_shape[1]
image = np.array(pil_image) image = np.array(pil_image)
@ -129,4 +158,12 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort)
print('\n'.join(sorted(result_tags_print, reverse=True))) print('\n'.join(sorted(result_tags_print, reverse=True)))
return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ') tags_text = ', '.join(result_tags_out)
if use_spaces:
tags_text = tags_text.replace('_', ' ')
if use_escape:
tags_text = re.sub(re_special, r'\\\1', tags_text)
return tags_text.replace(':', ' ')

View file

@ -14,7 +14,7 @@ import torch
from torch import einsum from torch import einsum
from einops import rearrange, repeat from einops import rearrange, repeat
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnSchedule from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module): class HypernetworkModule(torch.nn.Module):
@ -223,31 +223,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if ititial_step > steps: if ititial_step > steps:
return hypernetwork, filename return hypernetwork, filename
schedules = iter(LearnSchedule(learn_rate, steps, ititial_step)) scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
(learn_rate, end_step) = next(schedules) optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
print(f'Training at rate of {learn_rate} until step {end_step}')
optimizer = torch.optim.AdamW(weights, lr=learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, (x, text, cond) in pbar: for i, entry in pbar:
hypernetwork.step = i + ititial_step hypernetwork.step = i + ititial_step
if hypernetwork.step > end_step: scheduler.apply(optimizer, hypernetwork.step)
try: if scheduler.finished:
(learn_rate, end_step) = next(schedules)
except Exception:
break break
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
for pg in optimizer.param_groups:
pg['lr'] = learn_rate
if shared.state.interrupted: if shared.state.interrupted:
break break
with torch.autocast("cuda"): with torch.autocast("cuda"):
cond = cond.to(devices.device) cond = entry.cond.to(devices.device)
x = x.to(devices.device) x = entry.latent.to(devices.device)
loss = shared.sd_model(x.unsqueeze(0), cond)[0] loss = shared.sd_model(x.unsqueeze(0), cond)[0]
del x del x
del cond del cond
@ -267,7 +259,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
preview_text = text if preview_image_prompt == "" else preview_image_prompt preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
optimizer.zero_grad() optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.cond_stage_model.to(devices.device)
@ -282,15 +274,15 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
) )
processed = processing.process_images(p) processed = processing.process_images(p)
image = processed.images[0] image = processed.images[0] if len(processed.images)>0 else None
if unload: if unload:
shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu)
if image is not None:
shared.state.current_image = image shared.state.current_image = image
image.save(last_saved_image) image.save(last_saved_image)
last_saved_image += f", prompt: {preview_text}" last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step shared.state.job_no = hypernetwork.step
@ -299,7 +291,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
<p> <p>
Loss: {losses.mean():.7f}<br/> Loss: {losses.mean():.7f}<br/>
Step: {hypernetwork.step}<br/> Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(text)}<br/> Last prompt: {html.escape(entry.cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/> Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/> Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>

View file

@ -231,6 +231,9 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"), "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(100, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
})) }))
options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {
@ -257,6 +260,8 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
"deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
})) }))
options_templates.update(options_section(('ui', "User interface"), { options_templates.update(options_section(('ui', "User interface"), {

View file

@ -11,11 +11,21 @@ import tqdm
from modules import devices, shared from modules import devices, shared
import re import re
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+") re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class DatasetEntry:
def __init__(self, filename=None, latent=None, filename_text=None):
self.filename = filename
self.latent = latent
self.filename_text = filename_text
self.cond = None
self.cond_text = None
class PersonalizedBase(Dataset): class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False): def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None
self.placeholder_token = placeholder_token self.placeholder_token = placeholder_token
@ -42,9 +52,18 @@ class PersonalizedBase(Dataset):
except Exception: except Exception:
continue continue
text_filename = os.path.splitext(path)[0] + ".txt"
filename = os.path.basename(path) filename = os.path.basename(path)
filename_tokens = os.path.splitext(filename)[0]
filename_tokens = re_tag.findall(filename_tokens) if os.path.exists(text_filename):
with open(text_filename, "r", encoding="utf8") as file:
filename_text = file.read()
else:
filename_text = os.path.splitext(filename)[0]
filename_text = re.sub(re_numbers_at_start, '', filename_text)
if re_word:
tokens = re_word.findall(filename_text)
filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
npimage = np.array(image).astype(np.uint8) npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32) npimage = (npimage / 127.5 - 1.0).astype(np.float32)
@ -55,13 +74,13 @@ class PersonalizedBase(Dataset):
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
init_latent = init_latent.to(devices.cpu) init_latent = init_latent.to(devices.cpu)
if include_cond: entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
text = self.create_text(filename_tokens)
cond = cond_model([text]).to(devices.cpu)
else:
cond = None
self.dataset.append((init_latent, filename_tokens, cond)) if include_cond:
entry.cond_text = self.create_text(filename_text)
entry.cond = cond_model([entry.cond_text]).to(devices.cpu)
self.dataset.append(entry)
self.length = len(self.dataset) * repeats self.length = len(self.dataset) * repeats
@ -72,10 +91,10 @@ class PersonalizedBase(Dataset):
def shuffle(self): def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])] self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
def create_text(self, filename_tokens): def create_text(self, filename_text):
text = random.choice(self.lines) text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token) text = text.replace("[name]", self.placeholder_token)
text = text.replace("[filewords]", ' '.join(filename_tokens)) text = text.replace("[filewords]", filename_text)
return text return text
def __len__(self): def __len__(self):
@ -86,7 +105,9 @@ class PersonalizedBase(Dataset):
self.shuffle() self.shuffle()
index = self.indexes[i % len(self.indexes)] index = self.indexes[i % len(self.indexes)]
x, filename_tokens, cond = self.dataset[index] entry = self.dataset[index]
text = self.create_text(filename_tokens) if entry.cond is None:
return x, text, cond entry.cond_text = self.create_text(entry.filename_text)
return entry

View file

@ -1,6 +1,12 @@
import tqdm
class LearnSchedule:
class LearnScheduleIterator:
def __init__(self, learn_rate, max_steps, cur_step=0): def __init__(self, learn_rate, max_steps, cur_step=0):
"""
specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000
"""
pairs = learn_rate.split(',') pairs = learn_rate.split(',')
self.rates = [] self.rates = []
self.it = 0 self.it = 0
@ -32,3 +38,32 @@ class LearnSchedule:
return self.rates[self.it - 1] return self.rates[self.it - 1]
else: else:
raise StopIteration raise StopIteration
class LearnRateScheduler:
def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
(self.learn_rate, self.end_step) = next(self.schedules)
self.verbose = verbose
if self.verbose:
print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
self.finished = False
def apply(self, optimizer, step_number):
if step_number <= self.end_step:
return
try:
(self.learn_rate, self.end_step) = next(self.schedules)
except Exception:
self.finished = True
return
if self.verbose:
tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
for pg in optimizer.param_groups:
pg['lr'] = self.learn_rate

View file

@ -10,7 +10,28 @@ from modules.shared import opts, cmd_opts
if cmd_opts.deepdanbooru: if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru import modules.deepbooru as deepbooru
def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
try:
if process_caption:
shared.interrogator.load()
if process_caption_deepbooru:
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, deepbooru.create_deepbooru_opts())
preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
finally:
if process_caption:
shared.interrogator.send_blip_to_ram()
if process_caption_deepbooru:
deepbooru.release_process()
def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
width = process_width width = process_width
height = process_height height = process_height
src = os.path.abspath(process_src) src = os.path.abspath(process_src)
@ -25,30 +46,28 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
shared.state.textinfo = "Preprocessing..." shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files) shared.state.job_count = len(files)
def save_pic_with_caption(image, index):
caption = ""
if process_caption: if process_caption:
shared.interrogator.load() caption += shared.interrogator.generate_caption(image)
if process_caption_deepbooru: if process_caption_deepbooru:
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, opts.deepbooru_sort_alpha) if len(caption) > 0:
caption += ", "
caption += deepbooru.get_tags_from_process(image)
def save_pic_with_caption(image, index): filename_part = filename
if process_caption: filename_part = os.path.splitext(filename_part)[0]
caption = "-" + shared.interrogator.generate_caption(image) filename_part = os.path.basename(filename_part)
caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
elif process_caption_deepbooru: basename = f"{index:05}-{subindex[0]}-{filename_part}"
shared.deepbooru_process_return["value"] = -1 image.save(os.path.join(dst, f"{basename}.png"))
shared.deepbooru_process_queue.put(image)
while shared.deepbooru_process_return["value"] == -1: if len(caption) > 0:
time.sleep(0.2) with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
caption = "-" + shared.deepbooru_process_return["value"] file.write(caption)
caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
shared.deepbooru_process_return["value"] = -1
else:
caption = filename
caption = os.path.splitext(caption)[0]
caption = os.path.basename(caption)
image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
subindex[0] += 1 subindex[0] += 1
def save_pic(image, index): def save_pic(image, index):
@ -93,34 +112,3 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
save_pic(img, index) save_pic(img, index)
shared.state.nextjob() shared.state.nextjob()
if process_caption:
shared.interrogator.send_blip_to_ram()
if process_caption_deepbooru:
deepbooru.release_process()
def sanitize_caption(base_path, original_caption, suffix):
operating_system = platform.system().lower()
if (operating_system == "windows"):
invalid_path_characters = "\\/:*?\"<>|"
max_path_length = 259
else:
invalid_path_characters = "/" #linux/macos
max_path_length = 1023
caption = original_caption
for invalid_character in invalid_path_characters:
caption = caption.replace(invalid_character, "")
fixed_path_length = len(base_path) + len(suffix)
if fixed_path_length + len(caption) <= max_path_length:
return caption
caption_tokens = caption.split()
new_caption = ""
for token in caption_tokens:
last_caption = new_caption
new_caption = new_caption + token + " "
if (len(new_caption) + fixed_path_length - 1 > max_path_length):
break
print(f"\nPath will be too long. Truncated caption: {original_caption}\nto: {last_caption}", file=sys.stderr)
return last_caption.strip()

View file

@ -11,7 +11,7 @@ from PIL import Image, PngImagePlugin
from modules import shared, devices, sd_hijack, processing, sd_models from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnSchedule from modules.textual_inversion.learn_schedule import LearnRateScheduler
from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64, from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
insert_image_data_embed, extract_image_data_embed, insert_image_data_embed, extract_image_data_embed,
@ -172,8 +172,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn return fn
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
assert embedding_name, 'embedding not selected' assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..." shared.state.textinfo = "Initializing textual inversion training..."
@ -205,7 +204,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"): with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack hijack = sd_hijack.model_hijack
@ -221,32 +220,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if ititial_step > steps: if ititial_step > steps:
return embedding, filename return embedding, filename
schedules = iter(LearnSchedule(learn_rate, steps, ititial_step)) scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
(learn_rate, end_step) = next(schedules) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
print(f'Training at rate of {learn_rate} until step {end_step}')
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, (x, text, _) in pbar: for i, entry in pbar:
embedding.step = i + ititial_step embedding.step = i + ititial_step
if embedding.step > end_step: scheduler.apply(optimizer, embedding.step)
try: if scheduler.finished:
(learn_rate, end_step) = next(schedules)
except:
break break
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
for pg in optimizer.param_groups:
pg['lr'] = learn_rate
if shared.state.interrupted: if shared.state.interrupted:
break break
with torch.autocast("cuda"): with torch.autocast("cuda"):
c = cond_model([text]) c = cond_model([entry.cond_text])
x = x.to(devices.device) x = entry.latent.to(devices.device)
loss = shared.sd_model(x.unsqueeze(0), c)[0] loss = shared.sd_model(x.unsqueeze(0), c)[0]
del x del x
@ -268,7 +259,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
preview_text = text if preview_image_prompt == "" else preview_image_prompt preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
p = processing.StableDiffusionProcessingTxt2Img( p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model, sd_model=shared.sd_model,
@ -314,7 +305,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
<p> <p>
Loss: {losses.mean():.7f}<br/> Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/> Step: {embedding.step}<br/>
Last prompt: {html.escape(text)}<br/> Last prompt: {html.escape(entry.cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/> Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/> Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>

View file

@ -1082,11 +1082,8 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row(): with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies') process_flip = gr.Checkbox(label='Create flipped copies')
process_split = gr.Checkbox(label='Split oversized images into two') process_split = gr.Checkbox(label='Split oversized images into two')
process_caption = gr.Checkbox(label='Use BLIP caption as filename') process_caption = gr.Checkbox(label='Use BLIP for caption')
if cmd_opts.deepdanbooru: process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru caption as filename')
else:
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru caption as filename', visible=False)
with gr.Row(): with gr.Row():
with gr.Column(scale=3): with gr.Column(scale=3):
@ -1106,7 +1103,6 @@ def create_ui(wrap_gradio_gpu_call):
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0) steps = gr.Number(label='Max steps', value=100000, precision=0)
num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
@ -1184,7 +1180,6 @@ def create_ui(wrap_gradio_gpu_call):
training_width, training_width,
training_height, training_height,
steps, steps,
num_repeats,
create_image_every, create_image_every,
save_embedding_every, save_embedding_every,
template_file, template_file,