Merge branch 'master' into master
This commit is contained in:
commit
ea8aa1701a
11 changed files with 216 additions and 136 deletions
|
@ -163,10 +163,15 @@ function images_history_init(){
|
||||||
for (var i in images_history_tab_list){
|
for (var i in images_history_tab_list){
|
||||||
var tabname = images_history_tab_list[i]
|
var tabname = images_history_tab_list[i]
|
||||||
tab_btns[i].setAttribute("tabname", tabname);
|
tab_btns[i].setAttribute("tabname", tabname);
|
||||||
tab_btns[i].addEventListener('click', images_history_click_tab);
|
|
||||||
|
// this refreshes history upon tab switch
|
||||||
|
// until the history is known to work well, which is not the case now, we do not do this at startup
|
||||||
|
//tab_btns[i].addEventListener('click', images_history_click_tab);
|
||||||
}
|
}
|
||||||
tabs_box.classList.add(images_history_tab_list[0]);
|
tabs_box.classList.add(images_history_tab_list[0]);
|
||||||
load_txt2img_button.click();
|
|
||||||
|
// same as above, at page load
|
||||||
|
//load_txt2img_button.click();
|
||||||
} else {
|
} else {
|
||||||
setTimeout(images_history_init, 500);
|
setTimeout(images_history_init, 500);
|
||||||
}
|
}
|
||||||
|
@ -182,12 +187,15 @@ document.addEventListener("DOMContentLoaded", function() {
|
||||||
buttons.forEach(function(bnt){
|
buttons.forEach(function(bnt){
|
||||||
bnt.addEventListener('click', images_history_click_image, true);
|
bnt.addEventListener('click', images_history_click_image, true);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// same as load_txt2img_button.click() above
|
||||||
|
/*
|
||||||
var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg");
|
var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg");
|
||||||
if (cls_btn){
|
if (cls_btn){
|
||||||
cls_btn.addEventListener('click', function(){
|
cls_btn.addEventListener('click', function(){
|
||||||
gradioApp().getElementById(tabname + '_images_history_renew_page').click();
|
gradioApp().getElementById(tabname + '_images_history_renew_page').click();
|
||||||
}, false);
|
}, false);
|
||||||
}
|
}*/
|
||||||
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
// code related to showing and updating progressbar shown as the image is being made
|
// code related to showing and updating progressbar shown as the image is being made
|
||||||
global_progressbars = {}
|
global_progressbars = {}
|
||||||
|
galleries = {}
|
||||||
|
galleryObservers = {}
|
||||||
|
|
||||||
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
|
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
|
||||||
var progressbar = gradioApp().getElementById(id_progressbar)
|
var progressbar = gradioApp().getElementById(id_progressbar)
|
||||||
|
@ -31,21 +33,54 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
|
||||||
preview.style.width = gallery.clientWidth + "px"
|
preview.style.width = gallery.clientWidth + "px"
|
||||||
preview.style.height = gallery.clientHeight + "px"
|
preview.style.height = gallery.clientHeight + "px"
|
||||||
|
|
||||||
|
//only watch gallery if there is a generation process going on
|
||||||
|
check_gallery(id_gallery);
|
||||||
|
|
||||||
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
|
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
|
||||||
if(!progressDiv){
|
if(!progressDiv){
|
||||||
if (skip) {
|
if (skip) {
|
||||||
skip.style.display = "none"
|
skip.style.display = "none"
|
||||||
}
|
}
|
||||||
interrupt.style.display = "none"
|
interrupt.style.display = "none"
|
||||||
|
|
||||||
|
//disconnect observer once generation finished, so user can close selected image if they want
|
||||||
|
if (galleryObservers[id_gallery]) {
|
||||||
|
galleryObservers[id_gallery].disconnect();
|
||||||
|
galleries[id_gallery] = null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500)
|
window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500)
|
||||||
});
|
});
|
||||||
mutationObserver.observe( progressbar, { childList:true, subtree:true })
|
mutationObserver.observe( progressbar, { childList:true, subtree:true })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function check_gallery(id_gallery){
|
||||||
|
let gallery = gradioApp().getElementById(id_gallery)
|
||||||
|
// if gallery has no change, no need to setting up observer again.
|
||||||
|
if (gallery && galleries[id_gallery] !== gallery){
|
||||||
|
galleries[id_gallery] = gallery;
|
||||||
|
if(galleryObservers[id_gallery]){
|
||||||
|
galleryObservers[id_gallery].disconnect();
|
||||||
|
}
|
||||||
|
let prevSelectedIndex = selected_gallery_index();
|
||||||
|
galleryObservers[id_gallery] = new MutationObserver(function (){
|
||||||
|
let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
|
||||||
|
let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
|
||||||
|
if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
|
||||||
|
//automatically re-open previously selected index (if exists)
|
||||||
|
galleryButtons[prevSelectedIndex].click();
|
||||||
|
showGalleryImage();
|
||||||
|
}
|
||||||
|
})
|
||||||
|
galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
onUiUpdate(function(){
|
onUiUpdate(function(){
|
||||||
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
|
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
|
||||||
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
|
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
|
||||||
|
|
|
@ -187,12 +187,10 @@ onUiUpdate(function(){
|
||||||
if (!txt2img_textarea) {
|
if (!txt2img_textarea) {
|
||||||
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
|
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
|
||||||
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
|
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
|
||||||
txt2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "txt2img_generate"));
|
|
||||||
}
|
}
|
||||||
if (!img2img_textarea) {
|
if (!img2img_textarea) {
|
||||||
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
|
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
|
||||||
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
|
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
|
||||||
img2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "img2img_generate"));
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -220,14 +218,6 @@ function update_token_counter(button_id) {
|
||||||
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
||||||
}
|
}
|
||||||
|
|
||||||
function submit_prompt(event, generate_button_id) {
|
|
||||||
if (event.altKey && event.keyCode === 13) {
|
|
||||||
event.preventDefault();
|
|
||||||
gradioApp().getElementById(generate_button_id).click();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function restart_reload(){
|
function restart_reload(){
|
||||||
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
|
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
|
||||||
setTimeout(function(){location.reload()},2000)
|
setTimeout(function(){location.reload()},2000)
|
||||||
|
|
|
@ -9,6 +9,7 @@ import platform
|
||||||
dir_repos = "repositories"
|
dir_repos = "repositories"
|
||||||
python = sys.executable
|
python = sys.executable
|
||||||
git = os.environ.get('GIT', "git")
|
git = os.environ.get('GIT', "git")
|
||||||
|
index_url = os.environ.get('INDEX_URL',"")
|
||||||
|
|
||||||
|
|
||||||
def extract_arg(args, name):
|
def extract_arg(args, name):
|
||||||
|
@ -57,7 +58,7 @@ def run_python(code, desc=None, errdesc=None):
|
||||||
|
|
||||||
|
|
||||||
def run_pip(args, desc=None):
|
def run_pip(args, desc=None):
|
||||||
return run(f'"{python}" -m pip {args} --prefer-binary', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
|
return run(f'"{python}" -m pip {args} --prefer-binary{f' --index-url {index_url}' if index_url!='' else ''}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
|
||||||
|
|
||||||
|
|
||||||
def check_run_python(code):
|
def check_run_python(code):
|
||||||
|
|
|
@ -182,7 +182,21 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
|
||||||
|
|
||||||
def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
def stack_conds(conds):
|
||||||
|
if len(conds) == 1:
|
||||||
|
return torch.stack(conds)
|
||||||
|
|
||||||
|
# same as in reconstruct_multicond_batch
|
||||||
|
token_count = max([x.shape[0] for x in conds])
|
||||||
|
for i in range(len(conds)):
|
||||||
|
if conds[i].shape[0] != token_count:
|
||||||
|
last_vector = conds[i][-1:]
|
||||||
|
last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
|
||||||
|
conds[i] = torch.vstack([conds[i], last_vector_repeated])
|
||||||
|
|
||||||
|
return torch.stack(conds)
|
||||||
|
|
||||||
|
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
assert hypernetwork_name, 'hypernetwork not selected'
|
assert hypernetwork_name, 'hypernetwork not selected'
|
||||||
|
|
||||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||||
|
@ -211,7 +225,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||||
|
|
||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
|
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
|
||||||
|
|
||||||
if unload:
|
if unload:
|
||||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
|
@ -235,7 +249,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||||
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
|
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
|
||||||
|
|
||||||
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
||||||
for i, entry in pbar:
|
for i, entries in pbar:
|
||||||
hypernetwork.step = i + ititial_step
|
hypernetwork.step = i + ititial_step
|
||||||
|
|
||||||
scheduler.apply(optimizer, hypernetwork.step)
|
scheduler.apply(optimizer, hypernetwork.step)
|
||||||
|
@ -246,11 +260,12 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||||
break
|
break
|
||||||
|
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
cond = entry.cond.to(devices.device)
|
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
|
||||||
x = entry.latent.to(devices.device)
|
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
|
||||||
loss = shared.sd_model(x.unsqueeze(0), cond)[0]
|
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
|
||||||
|
loss = shared.sd_model(x, c)[0]
|
||||||
del x
|
del x
|
||||||
del cond
|
del c
|
||||||
|
|
||||||
losses[hypernetwork.step % losses.shape[0]] = loss.item()
|
losses[hypernetwork.step % losses.shape[0]] = loss.item()
|
||||||
|
|
||||||
|
@ -292,7 +307,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||||
p.width = preview_width
|
p.width = preview_width
|
||||||
p.height = preview_height
|
p.height = preview_height
|
||||||
else:
|
else:
|
||||||
p.prompt = entry.cond_text
|
p.prompt = entries[0].cond_text
|
||||||
p.steps = 20
|
p.steps = 20
|
||||||
|
|
||||||
preview_text = p.prompt
|
preview_text = p.prompt
|
||||||
|
@ -315,7 +330,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||||
<p>
|
<p>
|
||||||
Loss: {losses.mean():.7f}<br/>
|
Loss: {losses.mean():.7f}<br/>
|
||||||
Step: {hypernetwork.step}<br/>
|
Step: {hypernetwork.step}<br/>
|
||||||
Last prompt: {html.escape(entry.cond_text)}<br/>
|
Last prompt: {html.escape(entries[0].cond_text)}<br/>
|
||||||
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
</p>
|
</p>
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
|
|
||||||
def traverse_all_files(output_dir, image_list, curr_dir=None):
|
def traverse_all_files(output_dir, image_list, curr_dir=None):
|
||||||
curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
|
curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
|
||||||
try:
|
try:
|
||||||
|
@ -42,21 +44,33 @@ def get_recent_images(dir_name, page_index, step, image_index, tabname):
|
||||||
hidden = os.path.join(dir_name, current_file)
|
hidden = os.path.join(dir_name, current_file)
|
||||||
return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, ""
|
return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, ""
|
||||||
|
|
||||||
|
|
||||||
def first_page_click(dir_name, page_index, image_index, tabname):
|
def first_page_click(dir_name, page_index, image_index, tabname):
|
||||||
return get_recent_images(dir_name, 1, 0, image_index, tabname)
|
return get_recent_images(dir_name, 1, 0, image_index, tabname)
|
||||||
|
|
||||||
|
|
||||||
def end_page_click(dir_name, page_index, image_index, tabname):
|
def end_page_click(dir_name, page_index, image_index, tabname):
|
||||||
return get_recent_images(dir_name, -1, 0, image_index, tabname)
|
return get_recent_images(dir_name, -1, 0, image_index, tabname)
|
||||||
|
|
||||||
|
|
||||||
def prev_page_click(dir_name, page_index, image_index, tabname):
|
def prev_page_click(dir_name, page_index, image_index, tabname):
|
||||||
return get_recent_images(dir_name, page_index, -1, image_index, tabname)
|
return get_recent_images(dir_name, page_index, -1, image_index, tabname)
|
||||||
|
|
||||||
|
|
||||||
def next_page_click(dir_name, page_index, image_index, tabname):
|
def next_page_click(dir_name, page_index, image_index, tabname):
|
||||||
return get_recent_images(dir_name, page_index, 1, image_index, tabname)
|
return get_recent_images(dir_name, page_index, 1, image_index, tabname)
|
||||||
|
|
||||||
|
|
||||||
def page_index_change(dir_name, page_index, image_index, tabname):
|
def page_index_change(dir_name, page_index, image_index, tabname):
|
||||||
return get_recent_images(dir_name, page_index, 0, image_index, tabname)
|
return get_recent_images(dir_name, page_index, 0, image_index, tabname)
|
||||||
|
|
||||||
|
|
||||||
def show_image_info(num, image_path, filenames):
|
def show_image_info(num, image_path, filenames):
|
||||||
# print(f"select image {num}")
|
# print(f"select image {num}")
|
||||||
file = filenames[int(num)]
|
file = filenames[int(num)]
|
||||||
return file, num, os.path.join(image_path, file)
|
return file, num, os.path.join(image_path, file)
|
||||||
|
|
||||||
|
|
||||||
def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index):
|
def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index):
|
||||||
if name == "":
|
if name == "":
|
||||||
return filenames, delete_num
|
return filenames, delete_num
|
||||||
|
@ -81,15 +95,18 @@ def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, ima
|
||||||
i += 1
|
i += 1
|
||||||
return new_file_list, 1
|
return new_file_list, 1
|
||||||
|
|
||||||
|
|
||||||
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
|
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
|
||||||
if tabname == "txt2img":
|
if opts.outdir_samples != "":
|
||||||
|
dir_name = opts.outdir_samples
|
||||||
|
elif tabname == "txt2img":
|
||||||
dir_name = opts.outdir_txt2img_samples
|
dir_name = opts.outdir_txt2img_samples
|
||||||
elif tabname == "img2img":
|
elif tabname == "img2img":
|
||||||
dir_name = opts.outdir_img2img_samples
|
dir_name = opts.outdir_img2img_samples
|
||||||
elif tabname == "extras":
|
elif tabname == "extras":
|
||||||
dir_name = opts.outdir_extras_samples
|
dir_name = opts.outdir_extras_samples
|
||||||
d = dir_name.split("/")
|
d = dir_name.split("/")
|
||||||
dir_name = d[0]
|
dir_name = "/" if dir_name.startswith("/") else d[0]
|
||||||
for p in d[1:]:
|
for p in d[1:]:
|
||||||
dir_name = os.path.join(dir_name, p)
|
dir_name = os.path.join(dir_name, p)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
@ -126,7 +143,6 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
|
||||||
info1 = gr.Textbox(visible=False)
|
info1 = gr.Textbox(visible=False)
|
||||||
info2 = gr.Textbox(visible=False)
|
info2 = gr.Textbox(visible=False)
|
||||||
|
|
||||||
|
|
||||||
# turn pages
|
# turn pages
|
||||||
gallery_inputs = [img_path, page_index, image_index, tabname_box]
|
gallery_inputs = [img_path, page_index, image_index, tabname_box]
|
||||||
gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name]
|
gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name]
|
||||||
|
|
|
@ -140,7 +140,7 @@ class Processed:
|
||||||
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
|
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
|
||||||
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
|
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
|
||||||
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
|
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
|
||||||
self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
|
self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
|
||||||
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
|
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
|
||||||
|
|
||||||
self.all_prompts = all_prompts or [self.prompt]
|
self.all_prompts = all_prompts or [self.prompt]
|
||||||
|
|
|
@ -24,11 +24,12 @@ class DatasetEntry:
|
||||||
|
|
||||||
|
|
||||||
class PersonalizedBase(Dataset):
|
class PersonalizedBase(Dataset):
|
||||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
|
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
|
||||||
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
|
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
|
||||||
|
|
||||||
self.placeholder_token = placeholder_token
|
self.placeholder_token = placeholder_token
|
||||||
|
|
||||||
|
self.batch_size = batch_size
|
||||||
self.width = width
|
self.width = width
|
||||||
self.height = height
|
self.height = height
|
||||||
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||||
|
@ -78,14 +79,14 @@ class PersonalizedBase(Dataset):
|
||||||
|
|
||||||
if include_cond:
|
if include_cond:
|
||||||
entry.cond_text = self.create_text(filename_text)
|
entry.cond_text = self.create_text(filename_text)
|
||||||
entry.cond = cond_model([entry.cond_text]).to(devices.cpu)
|
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
||||||
|
|
||||||
self.dataset.append(entry)
|
self.dataset.append(entry)
|
||||||
|
|
||||||
assert len(self.dataset) > 1, "No images have been found in the dataset."
|
assert len(self.dataset) > 1, "No images have been found in the dataset."
|
||||||
self.length = len(self.dataset) * repeats
|
self.length = len(self.dataset) * repeats // batch_size
|
||||||
|
|
||||||
self.initial_indexes = np.arange(self.length) % len(self.dataset)
|
self.initial_indexes = np.arange(len(self.dataset))
|
||||||
self.indexes = None
|
self.indexes = None
|
||||||
self.shuffle()
|
self.shuffle()
|
||||||
|
|
||||||
|
@ -102,13 +103,19 @@ class PersonalizedBase(Dataset):
|
||||||
return self.length
|
return self.length
|
||||||
|
|
||||||
def __getitem__(self, i):
|
def __getitem__(self, i):
|
||||||
if i % len(self.dataset) == 0:
|
res = []
|
||||||
|
|
||||||
|
for j in range(self.batch_size):
|
||||||
|
position = i * self.batch_size + j
|
||||||
|
if position % len(self.indexes) == 0:
|
||||||
self.shuffle()
|
self.shuffle()
|
||||||
|
|
||||||
index = self.indexes[i % len(self.indexes)]
|
index = self.indexes[position % len(self.indexes)]
|
||||||
entry = self.dataset[index]
|
entry = self.dataset[index]
|
||||||
|
|
||||||
if entry.cond is None:
|
if entry.cond is None:
|
||||||
entry.cond_text = self.create_text(entry.filename_text)
|
entry.cond_text = self.create_text(entry.filename_text)
|
||||||
|
|
||||||
return entry
|
res.append(entry)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
|
@ -199,7 +199,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
assert embedding_name, 'embedding not selected'
|
assert embedding_name, 'embedding not selected'
|
||||||
|
|
||||||
shared.state.textinfo = "Initializing textual inversion training..."
|
shared.state.textinfo = "Initializing textual inversion training..."
|
||||||
|
@ -231,7 +231,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||||
|
|
||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
|
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
|
||||||
|
|
||||||
hijack = sd_hijack.model_hijack
|
hijack = sd_hijack.model_hijack
|
||||||
|
|
||||||
|
@ -251,7 +251,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
|
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
|
||||||
|
|
||||||
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
||||||
for i, entry in pbar:
|
for i, entries in pbar:
|
||||||
embedding.step = i + ititial_step
|
embedding.step = i + ititial_step
|
||||||
|
|
||||||
scheduler.apply(optimizer, embedding.step)
|
scheduler.apply(optimizer, embedding.step)
|
||||||
|
@ -262,10 +262,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||||
break
|
break
|
||||||
|
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
c = cond_model([entry.cond_text])
|
c = cond_model([entry.cond_text for entry in entries])
|
||||||
|
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
|
||||||
x = entry.latent.to(devices.device)
|
loss = shared.sd_model(x, c)[0]
|
||||||
loss = shared.sd_model(x.unsqueeze(0), c)[0]
|
|
||||||
del x
|
del x
|
||||||
|
|
||||||
losses[embedding.step % losses.shape[0]] = loss.item()
|
losses[embedding.step % losses.shape[0]] = loss.item()
|
||||||
|
@ -307,7 +306,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||||
p.width = preview_width
|
p.width = preview_width
|
||||||
p.height = preview_height
|
p.height = preview_height
|
||||||
else:
|
else:
|
||||||
p.prompt = entry.cond_text
|
p.prompt = entries[0].cond_text
|
||||||
p.steps = 20
|
p.steps = 20
|
||||||
p.width = training_width
|
p.width = training_width
|
||||||
p.height = training_height
|
p.height = training_height
|
||||||
|
@ -348,7 +347,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||||
<p>
|
<p>
|
||||||
Loss: {losses.mean():.7f}<br/>
|
Loss: {losses.mean():.7f}<br/>
|
||||||
Step: {embedding.step}<br/>
|
Step: {embedding.step}<br/>
|
||||||
Last prompt: {html.escape(entry.cond_text)}<br/>
|
Last prompt: {html.escape(entries[0].cond_text)}<br/>
|
||||||
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
</p>
|
</p>
|
||||||
|
|
|
@ -433,7 +433,10 @@ def create_toprow(is_img2img):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=80):
|
with gr.Column(scale=80):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2)
|
prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
|
||||||
|
placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
|
||||||
|
)
|
||||||
|
|
||||||
with gr.Column(scale=1, elem_id="roll_col"):
|
with gr.Column(scale=1, elem_id="roll_col"):
|
||||||
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
||||||
paste = gr.Button(value=paste_symbol, elem_id="paste")
|
paste = gr.Button(value=paste_symbol, elem_id="paste")
|
||||||
|
@ -446,7 +449,10 @@ def create_toprow(is_img2img):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=8):
|
with gr.Column(scale=8):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
|
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
|
||||||
|
placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
|
||||||
|
)
|
||||||
|
|
||||||
with gr.Column(scale=1, elem_id="roll_col"):
|
with gr.Column(scale=1, elem_id="roll_col"):
|
||||||
sh = gr.Button(elem_id="sh", visible=True)
|
sh = gr.Button(elem_id="sh", visible=True)
|
||||||
|
|
||||||
|
@ -1090,7 +1096,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
"i2i":img2img_paste_fields
|
"i2i":img2img_paste_fields
|
||||||
}
|
}
|
||||||
|
|
||||||
#images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
|
images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
|
||||||
|
|
||||||
with gr.Blocks() as modelmerger_interface:
|
with gr.Blocks() as modelmerger_interface:
|
||||||
with gr.Row().style(equal_height=False):
|
with gr.Row().style(equal_height=False):
|
||||||
|
@ -1166,6 +1172,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||||
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
|
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
|
||||||
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
|
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
|
||||||
|
batch_size = gr.Number(label='Batch size', value=1, precision=0)
|
||||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
||||||
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
|
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
|
||||||
|
@ -1244,6 +1251,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
inputs=[
|
inputs=[
|
||||||
train_embedding_name,
|
train_embedding_name,
|
||||||
learn_rate,
|
learn_rate,
|
||||||
|
batch_size,
|
||||||
dataset_directory,
|
dataset_directory,
|
||||||
log_directory,
|
log_directory,
|
||||||
training_width,
|
training_width,
|
||||||
|
@ -1268,6 +1276,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
inputs=[
|
inputs=[
|
||||||
train_hypernetwork_name,
|
train_hypernetwork_name,
|
||||||
learn_rate,
|
learn_rate,
|
||||||
|
batch_size,
|
||||||
dataset_directory,
|
dataset_directory,
|
||||||
log_directory,
|
log_directory,
|
||||||
steps,
|
steps,
|
||||||
|
@ -1487,7 +1496,7 @@ Requested path was: {f}
|
||||||
(img2img_interface, "img2img", "img2img"),
|
(img2img_interface, "img2img", "img2img"),
|
||||||
(extras_interface, "Extras", "extras"),
|
(extras_interface, "Extras", "extras"),
|
||||||
(pnginfo_interface, "PNG Info", "pnginfo"),
|
(pnginfo_interface, "PNG Info", "pnginfo"),
|
||||||
#(images_history, "History", "images_history"),
|
(images_history, "History", "images_history"),
|
||||||
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
|
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
|
||||||
(train_interface, "Train", "ti"),
|
(train_interface, "Train", "ti"),
|
||||||
(settings_interface, "Settings", "settings"),
|
(settings_interface, "Settings", "settings"),
|
||||||
|
|
|
@ -50,9 +50,9 @@ document.addEventListener("DOMContentLoaded", function() {
|
||||||
document.addEventListener('keydown', function(e) {
|
document.addEventListener('keydown', function(e) {
|
||||||
var handled = false;
|
var handled = false;
|
||||||
if (e.key !== undefined) {
|
if (e.key !== undefined) {
|
||||||
if((e.key == "Enter" && (e.metaKey || e.ctrlKey))) handled = true;
|
if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
|
||||||
} else if (e.keyCode !== undefined) {
|
} else if (e.keyCode !== undefined) {
|
||||||
if((e.keyCode == 13 && (e.metaKey || e.ctrlKey))) handled = true;
|
if((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
|
||||||
}
|
}
|
||||||
if (handled) {
|
if (handled) {
|
||||||
button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
|
button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
|
||||||
|
|
Loading…
Reference in a new issue