Merge branch 'AUTOMATIC1111:master' into master
This commit is contained in:
commit
d41ac174e2
26 changed files with 407 additions and 171 deletions
36
.github/workflows/on_pull_request.yaml
vendored
Normal file
36
.github/workflows/on_pull_request.yaml
vendored
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
# See https://github.com/actions/starter-workflows/blob/1067f16ad8a1eac328834e4b0ae24f7d206f810d/ci/pylint.yml for original reference file
# Lints every Python file tracked by git with PyLint on each push and pull request.
name: Run Linting/Formatting on Pull Requests

on:
  - push
  - pull_request
  # See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#onpull_requestpull_request_targetbranchesbranches-ignore for syntax docs
  # To filter branches, the list form above must be replaced by the mapping form:
  # a YAML node cannot be a list and a mapping at once, so `- push` becomes `push:`
  # and `- pull_request` becomes the block below.
  # pull_request:
  #   branches:
  #     - master
  #   branches-ignore:
  #     - development

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout Code
        uses: actions/checkout@v3
      - name: Set up Python 3.10
        uses: actions/setup-python@v3
        with:
          # Quoted so the version is always read as a string, never as a number.
          python-version: "3.10.6"
      - name: Install PyLint
        run: |
          python -m pip install --upgrade pip
          pip install pylint
      # This lets PyLint check to see if it can resolve imports
      - name: Install dependencies
        run: |
          export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
          python launch.py
      - name: Analysing the code with pylint
        run: |
          pylint $(git ls-files '*.py')
|
3
.pylintrc
Normal file
3
.pylintrc
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
# See https://pylint.pycqa.org/en/latest/user_guide/messages/message_control.html
|
||||||
|
[MESSAGES CONTROL]
|
||||||
|
disable=C,R,W,E,I
|
4
javascript/dragdrop.js
vendored
4
javascript/dragdrop.js
vendored
|
@ -43,7 +43,7 @@ function dropReplaceImage( imgWrap, files ) {
|
||||||
window.document.addEventListener('dragover', e => {
|
window.document.addEventListener('dragover', e => {
|
||||||
const target = e.composedPath()[0];
|
const target = e.composedPath()[0];
|
||||||
const imgWrap = target.closest('[data-testid="image"]');
|
const imgWrap = target.closest('[data-testid="image"]');
|
||||||
if ( !imgWrap && target.placeholder != "Prompt") {
|
if ( !imgWrap && target.placeholder.indexOf("Prompt") == -1) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
e.stopPropagation();
|
e.stopPropagation();
|
||||||
|
@ -53,7 +53,7 @@ window.document.addEventListener('dragover', e => {
|
||||||
|
|
||||||
window.document.addEventListener('drop', e => {
|
window.document.addEventListener('drop', e => {
|
||||||
const target = e.composedPath()[0];
|
const target = e.composedPath()[0];
|
||||||
if (target.placeholder === "Prompt") {
|
if (target.placeholder.indexOf("Prompt") == -1) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const imgWrap = target.closest('[data-testid="image"]');
|
const imgWrap = target.closest('[data-testid="image"]');
|
||||||
|
|
|
@ -2,6 +2,8 @@ addEventListener('keydown', (event) => {
|
||||||
let target = event.originalTarget || event.composedPath()[0];
|
let target = event.originalTarget || event.composedPath()[0];
|
||||||
if (!target.hasAttribute("placeholder")) return;
|
if (!target.hasAttribute("placeholder")) return;
|
||||||
if (!target.placeholder.toLowerCase().includes("prompt")) return;
|
if (!target.placeholder.toLowerCase().includes("prompt")) return;
|
||||||
|
if (! (event.metaKey || event.ctrlKey)) return;
|
||||||
|
|
||||||
|
|
||||||
let plus = "ArrowUp"
|
let plus = "ArrowUp"
|
||||||
let minus = "ArrowDown"
|
let minus = "ArrowDown"
|
||||||
|
|
|
@ -16,6 +16,8 @@ titles = {
|
||||||
"\u{1f3a8}": "Add a random artist to the prompt.",
|
"\u{1f3a8}": "Add a random artist to the prompt.",
|
||||||
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
|
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
|
||||||
"\u{1f4c2}": "Open images output directory",
|
"\u{1f4c2}": "Open images output directory",
|
||||||
|
"\u{1f4be}": "Save style",
|
||||||
|
"\u{1f4cb}": "Apply selected styles to current prompt",
|
||||||
|
|
||||||
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
|
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
|
||||||
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
|
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
|
||||||
|
|
|
@ -2,7 +2,7 @@ window.onload = (function(){
|
||||||
window.addEventListener('drop', e => {
|
window.addEventListener('drop', e => {
|
||||||
const target = e.composedPath()[0];
|
const target = e.composedPath()[0];
|
||||||
const idx = selected_gallery_index();
|
const idx = selected_gallery_index();
|
||||||
if (target.placeholder != "Prompt") return;
|
if (target.placeholder.indexOf("Prompt") == -1) return;
|
||||||
|
|
||||||
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
|
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
// code related to showing and updating progressbar shown as the image is being made
|
// code related to showing and updating progressbar shown as the image is being made
|
||||||
global_progressbars = {}
|
global_progressbars = {}
|
||||||
|
galleries = {}
|
||||||
|
galleryObservers = {}
|
||||||
|
|
||||||
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
|
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
|
||||||
var progressbar = gradioApp().getElementById(id_progressbar)
|
var progressbar = gradioApp().getElementById(id_progressbar)
|
||||||
|
@ -31,21 +33,54 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
|
||||||
preview.style.width = gallery.clientWidth + "px"
|
preview.style.width = gallery.clientWidth + "px"
|
||||||
preview.style.height = gallery.clientHeight + "px"
|
preview.style.height = gallery.clientHeight + "px"
|
||||||
|
|
||||||
|
//only watch gallery if there is a generation process going on
|
||||||
|
check_gallery(id_gallery);
|
||||||
|
|
||||||
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
|
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
|
||||||
if(!progressDiv){
|
if(!progressDiv){
|
||||||
if (skip) {
|
if (skip) {
|
||||||
skip.style.display = "none"
|
skip.style.display = "none"
|
||||||
}
|
}
|
||||||
interrupt.style.display = "none"
|
interrupt.style.display = "none"
|
||||||
|
|
||||||
|
//disconnect observer once generation finished, so user can close selected image if they want
|
||||||
|
if (galleryObservers[id_gallery]) {
|
||||||
|
galleryObservers[id_gallery].disconnect();
|
||||||
|
galleries[id_gallery] = null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500)
|
window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500)
|
||||||
});
|
});
|
||||||
mutationObserver.observe( progressbar, { childList:true, subtree:true })
|
mutationObserver.observe( progressbar, { childList:true, subtree:true })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Attach a MutationObserver to the gallery element with the given id so that
 * the image selected when the observer was created is re-opened whenever the
 * gallery's direct children change and no item is currently selected.
 *
 * Does nothing when the element is missing or is the one already recorded in
 * `galleries` for this id.
 *
 * @param {string} id_gallery - DOM id of the gallery element to watch.
 */
function check_gallery(id_gallery){
    const gallery = gradioApp().getElementById(id_gallery);

    // Same element as last time (or no element at all): keep the current observer.
    if (!gallery || galleries[id_gallery] === gallery) {
        return;
    }
    galleries[id_gallery] = gallery;

    // Detach the observer that was watching the previous gallery element, if any.
    const staleObserver = galleryObservers[id_gallery];
    if (staleObserver) {
        staleObserver.disconnect();
    }

    // The selection is sampled once, at the moment the observer is (re)created.
    const rememberedIndex = selected_gallery_index();

    const reopenSelection = () => {
        const buttons = gradioApp().querySelectorAll(`#${id_gallery} .gallery-item`);
        const selectedButton = gradioApp().querySelector(`#${id_gallery} .gallery-item.\\!ring-2`);
        const canReopen = rememberedIndex !== -1 && buttons.length > rememberedIndex;

        // Automatically re-open the previously selected index (if it still exists).
        if (canReopen && !selectedButton) {
            buttons[rememberedIndex].click();
            showGalleryImage();
        }
    };

    const observer = new MutationObserver(reopenSelection);
    galleryObservers[id_gallery] = observer;
    observer.observe(gallery, { childList: true, subtree: false });
}
|
||||||
|
|
||||||
onUiUpdate(function(){
|
onUiUpdate(function(){
|
||||||
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
|
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
|
||||||
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
|
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
|
||||||
|
|
|
@ -141,7 +141,7 @@ function submit_img2img(){
|
||||||
|
|
||||||
function ask_for_style_name(_, prompt_text, negative_prompt_text) {
|
function ask_for_style_name(_, prompt_text, negative_prompt_text) {
|
||||||
name_ = prompt('Style name:')
|
name_ = prompt('Style name:')
|
||||||
return name_ === null ? [null, null, null]: [name_, prompt_text, negative_prompt_text]
|
return [name_, prompt_text, negative_prompt_text]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -187,12 +187,10 @@ onUiUpdate(function(){
|
||||||
if (!txt2img_textarea) {
|
if (!txt2img_textarea) {
|
||||||
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
|
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
|
||||||
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
|
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
|
||||||
txt2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "txt2img_generate"));
|
|
||||||
}
|
}
|
||||||
if (!img2img_textarea) {
|
if (!img2img_textarea) {
|
||||||
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
|
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
|
||||||
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
|
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
|
||||||
img2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "img2img_generate"));
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -220,14 +218,6 @@ function update_token_counter(button_id) {
|
||||||
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
||||||
}
|
}
|
||||||
|
|
||||||
function submit_prompt(event, generate_button_id) {
|
|
||||||
if (event.altKey && event.keyCode === 13) {
|
|
||||||
event.preventDefault();
|
|
||||||
gradioApp().getElementById(generate_button_id).click();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function restart_reload(){
|
function restart_reload(){
|
||||||
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
|
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
|
||||||
setTimeout(function(){location.reload()},2000)
|
setTimeout(function(){location.reload()},2000)
|
||||||
|
|
|
@ -9,6 +9,7 @@ import platform
|
||||||
dir_repos = "repositories"
|
dir_repos = "repositories"
|
||||||
python = sys.executable
|
python = sys.executable
|
||||||
git = os.environ.get('GIT', "git")
|
git = os.environ.get('GIT', "git")
|
||||||
|
index_url = os.environ.get('INDEX_URL', "")
|
||||||
|
|
||||||
|
|
||||||
def extract_arg(args, name):
|
def extract_arg(args, name):
|
||||||
|
@ -57,7 +58,8 @@ def run_python(code, desc=None, errdesc=None):
|
||||||
|
|
||||||
|
|
||||||
def run_pip(args, desc=None):
|
def run_pip(args, desc=None):
|
||||||
return run(f'"{python}" -m pip {args} --prefer-binary', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
|
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
|
||||||
|
return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
|
||||||
|
|
||||||
|
|
||||||
def check_run_python(code):
|
def check_run_python(code):
|
||||||
|
@ -102,6 +104,7 @@ def prepare_enviroment():
|
||||||
args = shlex.split(commandline_args)
|
args = shlex.split(commandline_args)
|
||||||
|
|
||||||
args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test')
|
args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test')
|
||||||
|
args, reinstall_xformers = extract_arg(args, '--reinstall-xformers')
|
||||||
xformers = '--xformers' in args
|
xformers = '--xformers' in args
|
||||||
deepdanbooru = '--deepdanbooru' in args
|
deepdanbooru = '--deepdanbooru' in args
|
||||||
ngrok = '--ngrok' in args
|
ngrok = '--ngrok' in args
|
||||||
|
@ -126,9 +129,9 @@ def prepare_enviroment():
|
||||||
if not is_installed("clip"):
|
if not is_installed("clip"):
|
||||||
run_pip(f"install {clip_package}", "clip")
|
run_pip(f"install {clip_package}", "clip")
|
||||||
|
|
||||||
if not is_installed("xformers") and xformers and platform.python_version().startswith("3.10"):
|
if (not is_installed("xformers") or reinstall_xformers) and xformers and platform.python_version().startswith("3.10"):
|
||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/c/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers")
|
run_pip("install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers")
|
||||||
elif platform.system() == "Linux":
|
elif platform.system() == "Linux":
|
||||||
run_pip("install xformers", "xformers")
|
run_pip("install xformers", "xformers")
|
||||||
|
|
||||||
|
|
|
@ -102,7 +102,7 @@ def get_deepbooru_tags_model():
|
||||||
|
|
||||||
tags = dd.project.load_tags_from_project(model_path)
|
tags = dd.project.load_tags_from_project(model_path)
|
||||||
model = dd.project.load_model_from_project(
|
model = dd.project.load_model_from_project(
|
||||||
model_path, compile_model=True
|
model_path, compile_model=False
|
||||||
)
|
)
|
||||||
return model, tags
|
return model, tags
|
||||||
|
|
||||||
|
|
|
@ -182,7 +182,21 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
|
||||||
|
|
||||||
def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
def stack_conds(conds):
    """Stack per-entry conditioning tensors into one batch tensor.

    Entries may have different lengths along their first dimension. Shorter
    entries are padded to the longest one by repeating their last row, the
    same approach the original comment attributes to
    reconstruct_multicond_batch.

    The input sequence is not modified; padded copies are collected in a new
    list. (The previous implementation assigned padded tensors back into the
    caller's list and therefore failed on tuples.)

    Args:
        conds: non-empty sequence of tensors. Padding uses a two-element
            repeat, so entries needing padding are assumed to be 2-D
            (rows, features) -- NOTE(review): confirm against callers.

    Returns:
        torch.Tensor: the entries stacked along a new leading dimension.

    Raises:
        ValueError: if ``conds`` is empty (raised by ``max``).
    """
    if len(conds) == 1:
        return torch.stack(list(conds))

    # same as in reconstruct_multicond_batch
    token_count = max(cond.shape[0] for cond in conds)

    padded = []
    for cond in conds:
        missing = token_count - cond.shape[0]
        if missing > 0:
            # Repeat the final row until this entry matches the longest one.
            last_vector = cond[-1:]
            cond = torch.vstack([cond, last_vector.repeat([missing, 1])])
        padded.append(cond)

    return torch.stack(padded)
|
||||||
|
|
||||||
|
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
assert hypernetwork_name, 'hypernetwork not selected'
|
assert hypernetwork_name, 'hypernetwork not selected'
|
||||||
|
|
||||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||||
|
@ -211,7 +225,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||||
|
|
||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
|
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
|
||||||
|
|
||||||
if unload:
|
if unload:
|
||||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
|
@ -235,7 +249,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||||
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
|
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
|
||||||
|
|
||||||
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
||||||
for i, entry in pbar:
|
for i, entries in pbar:
|
||||||
hypernetwork.step = i + ititial_step
|
hypernetwork.step = i + ititial_step
|
||||||
|
|
||||||
scheduler.apply(optimizer, hypernetwork.step)
|
scheduler.apply(optimizer, hypernetwork.step)
|
||||||
|
@ -246,26 +260,29 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||||
break
|
break
|
||||||
|
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
cond = entry.cond.to(devices.device)
|
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
|
||||||
x = entry.latent.to(devices.device)
|
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
|
||||||
loss = shared.sd_model(x.unsqueeze(0), cond)[0]
|
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
|
||||||
|
loss = shared.sd_model(x, c)[0]
|
||||||
del x
|
del x
|
||||||
del cond
|
del c
|
||||||
|
|
||||||
losses[hypernetwork.step % losses.shape[0]] = loss.item()
|
losses[hypernetwork.step % losses.shape[0]] = loss.item()
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
mean_loss = losses.mean()
|
||||||
pbar.set_description(f"loss: {losses.mean():.7f}")
|
if torch.isnan(mean_loss):
|
||||||
|
raise RuntimeError("Loss diverged.")
|
||||||
|
pbar.set_description(f"loss: {mean_loss:.7f}")
|
||||||
|
|
||||||
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
|
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
|
||||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
|
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
|
||||||
hypernetwork.save(last_saved_file)
|
hypernetwork.save(last_saved_file)
|
||||||
|
|
||||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
||||||
"loss": f"{losses.mean():.7f}",
|
"loss": f"{mean_loss:.7f}",
|
||||||
"learn_rate": scheduler.learn_rate
|
"learn_rate": scheduler.learn_rate
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -292,7 +309,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||||
p.width = preview_width
|
p.width = preview_width
|
||||||
p.height = preview_height
|
p.height = preview_height
|
||||||
else:
|
else:
|
||||||
p.prompt = entry.cond_text
|
p.prompt = entries[0].cond_text
|
||||||
p.steps = 20
|
p.steps = 20
|
||||||
|
|
||||||
preview_text = p.prompt
|
preview_text = p.prompt
|
||||||
|
@ -313,9 +330,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||||
|
|
||||||
shared.state.textinfo = f"""
|
shared.state.textinfo = f"""
|
||||||
<p>
|
<p>
|
||||||
Loss: {losses.mean():.7f}<br/>
|
Loss: {mean_loss:.7f}<br/>
|
||||||
Step: {hypernetwork.step}<br/>
|
Step: {hypernetwork.step}<br/>
|
||||||
Last prompt: {html.escape(entry.cond_text)}<br/>
|
Last prompt: {html.escape(entries[0].cond_text)}<br/>
|
||||||
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
</p>
|
</p>
|
||||||
|
|
|
@ -97,14 +97,16 @@ def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, ima
|
||||||
|
|
||||||
|
|
||||||
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
|
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
|
||||||
if tabname == "txt2img":
|
if opts.outdir_samples != "":
|
||||||
|
dir_name = opts.outdir_samples
|
||||||
|
elif tabname == "txt2img":
|
||||||
dir_name = opts.outdir_txt2img_samples
|
dir_name = opts.outdir_txt2img_samples
|
||||||
elif tabname == "img2img":
|
elif tabname == "img2img":
|
||||||
dir_name = opts.outdir_img2img_samples
|
dir_name = opts.outdir_img2img_samples
|
||||||
elif tabname == "extras":
|
elif tabname == "extras":
|
||||||
dir_name = opts.outdir_extras_samples
|
dir_name = opts.outdir_extras_samples
|
||||||
d = dir_name.split("/")
|
d = dir_name.split("/")
|
||||||
dir_name = d[0]
|
dir_name = "/" if dir_name.startswith("/") else d[0]
|
||||||
for p in d[1:]:
|
for p in d[1:]:
|
||||||
dir_name = os.path.join(dir_name, p)
|
dir_name = os.path.join(dir_name, p)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
|
|
@ -140,7 +140,7 @@ class Processed:
|
||||||
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
|
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
|
||||||
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
|
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
|
||||||
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
|
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
|
||||||
self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
|
self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
|
||||||
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
|
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
|
||||||
|
|
||||||
self.all_prompts = all_prompts or [self.prompt]
|
self.all_prompts = all_prompts or [self.prompt]
|
||||||
|
@ -528,7 +528,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
firstphase_height_truncated = int(scale * self.height)
|
firstphase_height_truncated = int(scale * self.height)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
|
|
||||||
|
|
||||||
width_ratio = self.width / self.firstphase_width
|
width_ratio = self.width / self.firstphase_width
|
||||||
height_ratio = self.height / self.firstphase_height
|
height_ratio = self.height / self.firstphase_height
|
||||||
|
@ -540,6 +539,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
firstphase_width_truncated = self.firstphase_height * self.width / self.height
|
firstphase_width_truncated = self.firstphase_height * self.width / self.height
|
||||||
firstphase_height_truncated = self.firstphase_height
|
firstphase_height_truncated = self.firstphase_height
|
||||||
|
|
||||||
|
self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
|
||||||
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
|
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
|
||||||
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
|
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
|
||||||
|
|
||||||
|
@ -557,11 +557,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
|
|
||||||
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
|
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
|
||||||
|
|
||||||
decoded_samples = decode_first_stage(self.sd_model, samples)
|
if opts.use_scale_latent_for_hires_fix:
|
||||||
|
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
|
||||||
|
|
||||||
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
|
|
||||||
decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
|
|
||||||
else:
|
else:
|
||||||
|
decoded_samples = decode_first_stage(self.sd_model, samples)
|
||||||
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
batch_images = []
|
batch_images = []
|
||||||
|
|
|
@ -24,7 +24,7 @@ def apply_optimizations():
|
||||||
|
|
||||||
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||||
|
|
||||||
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)):
|
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
|
||||||
print("Applying xformers cross attention optimization.")
|
print("Applying xformers cross attention optimization.")
|
||||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
|
||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import glob
|
import collections
|
||||||
import os.path
|
import os.path
|
||||||
import sys
|
import sys
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
@ -15,6 +15,7 @@ model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||||
|
|
||||||
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
|
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
|
||||||
checkpoints_list = {}
|
checkpoints_list = {}
|
||||||
|
checkpoints_loaded = collections.OrderedDict()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||||
|
@ -132,15 +133,14 @@ def load_model_weights(model, checkpoint_info):
|
||||||
checkpoint_file = checkpoint_info.filename
|
checkpoint_file = checkpoint_info.filename
|
||||||
sd_model_hash = checkpoint_info.hash
|
sd_model_hash = checkpoint_info.hash
|
||||||
|
|
||||||
|
if checkpoint_info not in checkpoints_loaded:
|
||||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
|
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
|
||||||
|
|
||||||
pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
|
pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
|
||||||
|
|
||||||
if "global_step" in pl_sd:
|
if "global_step" in pl_sd:
|
||||||
print(f"Global Step: {pl_sd['global_step']}")
|
print(f"Global Step: {pl_sd['global_step']}")
|
||||||
|
|
||||||
sd = get_state_dict_from_checkpoint(pl_sd)
|
sd = get_state_dict_from_checkpoint(pl_sd)
|
||||||
|
|
||||||
model.load_state_dict(sd, strict=False)
|
model.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
if shared.cmd_opts.opt_channelslast:
|
if shared.cmd_opts.opt_channelslast:
|
||||||
|
@ -159,15 +159,20 @@ def load_model_weights(model, checkpoint_info):
|
||||||
|
|
||||||
if os.path.exists(vae_file):
|
if os.path.exists(vae_file):
|
||||||
print(f"Loading VAE weights from: {vae_file}")
|
print(f"Loading VAE weights from: {vae_file}")
|
||||||
|
|
||||||
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
||||||
|
|
||||||
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
||||||
|
|
||||||
model.first_stage_model.load_state_dict(vae_dict)
|
model.first_stage_model.load_state_dict(vae_dict)
|
||||||
|
|
||||||
model.first_stage_model.to(devices.dtype_vae)
|
model.first_stage_model.to(devices.dtype_vae)
|
||||||
|
|
||||||
|
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
||||||
|
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
|
||||||
|
checkpoints_loaded.popitem(last=False) # LRU
|
||||||
|
else:
|
||||||
|
print(f"Loading weights [{sd_model_hash}] from cache")
|
||||||
|
checkpoints_loaded.move_to_end(checkpoint_info)
|
||||||
|
model.load_state_dict(checkpoints_loaded[checkpoint_info])
|
||||||
|
|
||||||
model.sd_model_hash = sd_model_hash
|
model.sd_model_hash = sd_model_hash
|
||||||
model.sd_model_checkpoint = checkpoint_file
|
model.sd_model_checkpoint = checkpoint_file
|
||||||
model.sd_checkpoint_info = checkpoint_info
|
model.sd_checkpoint_info = checkpoint_info
|
||||||
|
@ -205,6 +210,7 @@ def reload_model_weights(sd_model, info=None):
|
||||||
return
|
return
|
||||||
|
|
||||||
if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
|
if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
|
||||||
|
checkpoints_loaded.clear()
|
||||||
shared.sd_model = load_model()
|
shared.sd_model = load_model()
|
||||||
return shared.sd_model
|
return shared.sd_model
|
||||||
|
|
||||||
|
|
|
@ -218,6 +218,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||||
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||||
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
|
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
|
||||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
||||||
|
"use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
||||||
|
@ -242,6 +243,7 @@ options_templates.update(options_section(('training', "Training"), {
|
||||||
|
|
||||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
|
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
|
||||||
|
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||||
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|
||||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||||
|
@ -255,7 +257,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
|
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
|
||||||
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
||||||
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
|
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
|
||||||
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
|
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
||||||
|
@ -283,6 +284,7 @@ options_templates.update(options_section(('ui', "User interface"), {
|
||||||
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
||||||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||||
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||||
|
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||||
|
|
|
@ -24,11 +24,12 @@ class DatasetEntry:
|
||||||
|
|
||||||
|
|
||||||
class PersonalizedBase(Dataset):
|
class PersonalizedBase(Dataset):
|
||||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
|
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
|
||||||
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
|
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
|
||||||
|
|
||||||
self.placeholder_token = placeholder_token
|
self.placeholder_token = placeholder_token
|
||||||
|
|
||||||
|
self.batch_size = batch_size
|
||||||
self.width = width
|
self.width = width
|
||||||
self.height = height
|
self.height = height
|
||||||
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||||
|
@ -78,13 +79,14 @@ class PersonalizedBase(Dataset):
|
||||||
|
|
||||||
if include_cond:
|
if include_cond:
|
||||||
entry.cond_text = self.create_text(filename_text)
|
entry.cond_text = self.create_text(filename_text)
|
||||||
entry.cond = cond_model([entry.cond_text]).to(devices.cpu)
|
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
||||||
|
|
||||||
self.dataset.append(entry)
|
self.dataset.append(entry)
|
||||||
|
|
||||||
self.length = len(self.dataset) * repeats
|
assert len(self.dataset) > 1, "No images have been found in the dataset."
|
||||||
|
self.length = len(self.dataset) * repeats // batch_size
|
||||||
|
|
||||||
self.initial_indexes = np.arange(self.length) % len(self.dataset)
|
self.initial_indexes = np.arange(len(self.dataset))
|
||||||
self.indexes = None
|
self.indexes = None
|
||||||
self.shuffle()
|
self.shuffle()
|
||||||
|
|
||||||
|
@ -101,13 +103,19 @@ class PersonalizedBase(Dataset):
|
||||||
return self.length
|
return self.length
|
||||||
|
|
||||||
def __getitem__(self, i):
|
def __getitem__(self, i):
|
||||||
if i % len(self.dataset) == 0:
|
res = []
|
||||||
|
|
||||||
|
for j in range(self.batch_size):
|
||||||
|
position = i * self.batch_size + j
|
||||||
|
if position % len(self.indexes) == 0:
|
||||||
self.shuffle()
|
self.shuffle()
|
||||||
|
|
||||||
index = self.indexes[i % len(self.indexes)]
|
index = self.indexes[position % len(self.indexes)]
|
||||||
entry = self.dataset[index]
|
entry = self.dataset[index]
|
||||||
|
|
||||||
if entry.cond is None:
|
if entry.cond is None:
|
||||||
entry.cond_text = self.create_text(entry.filename_text)
|
entry.cond_text = self.create_text(entry.filename_text)
|
||||||
|
|
||||||
return entry
|
res.append(entry)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
|
@ -88,9 +88,9 @@ class EmbeddingDatabase:
|
||||||
|
|
||||||
data = []
|
data = []
|
||||||
|
|
||||||
if filename.upper().endswith('.PNG'):
|
if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
||||||
embed_image = Image.open(path)
|
embed_image = Image.open(path)
|
||||||
if 'sd-ti-embedding' in embed_image.text:
|
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
|
||||||
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
||||||
name = data.get('name', name)
|
name = data.get('name', name)
|
||||||
else:
|
else:
|
||||||
|
@ -199,7 +199,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
assert embedding_name, 'embedding not selected'
|
assert embedding_name, 'embedding not selected'
|
||||||
|
|
||||||
shared.state.textinfo = "Initializing textual inversion training..."
|
shared.state.textinfo = "Initializing textual inversion training..."
|
||||||
|
@ -231,7 +231,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||||
|
|
||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
|
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
|
||||||
|
|
||||||
hijack = sd_hijack.model_hijack
|
hijack = sd_hijack.model_hijack
|
||||||
|
|
||||||
|
@ -242,6 +242,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||||
|
|
||||||
last_saved_file = "<none>"
|
last_saved_file = "<none>"
|
||||||
last_saved_image = "<none>"
|
last_saved_image = "<none>"
|
||||||
|
embedding_yet_to_be_embedded = False
|
||||||
|
|
||||||
ititial_step = embedding.step or 0
|
ititial_step = embedding.step or 0
|
||||||
if ititial_step > steps:
|
if ititial_step > steps:
|
||||||
|
@ -251,7 +252,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
|
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
|
||||||
|
|
||||||
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
||||||
for i, entry in pbar:
|
for i, entries in pbar:
|
||||||
embedding.step = i + ititial_step
|
embedding.step = i + ititial_step
|
||||||
|
|
||||||
scheduler.apply(optimizer, embedding.step)
|
scheduler.apply(optimizer, embedding.step)
|
||||||
|
@ -262,10 +263,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||||
break
|
break
|
||||||
|
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
c = cond_model([entry.cond_text])
|
c = cond_model([entry.cond_text for entry in entries])
|
||||||
|
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
|
||||||
x = entry.latent.to(devices.device)
|
loss = shared.sd_model(x, c)[0]
|
||||||
loss = shared.sd_model(x.unsqueeze(0), c)[0]
|
|
||||||
del x
|
del x
|
||||||
|
|
||||||
losses[embedding.step % losses.shape[0]] = loss.item()
|
losses[embedding.step % losses.shape[0]] = loss.item()
|
||||||
|
@ -282,6 +282,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||||
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
|
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
|
||||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
||||||
embedding.save(last_saved_file)
|
embedding.save(last_saved_file)
|
||||||
|
embedding_yet_to_be_embedded = True
|
||||||
|
|
||||||
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
|
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
|
||||||
"loss": f"{losses.mean():.7f}",
|
"loss": f"{losses.mean():.7f}",
|
||||||
|
@ -307,7 +308,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||||
p.width = preview_width
|
p.width = preview_width
|
||||||
p.height = preview_height
|
p.height = preview_height
|
||||||
else:
|
else:
|
||||||
p.prompt = entry.cond_text
|
p.prompt = entries[0].cond_text
|
||||||
p.steps = 20
|
p.steps = 20
|
||||||
p.width = training_width
|
p.width = training_width
|
||||||
p.height = training_height
|
p.height = training_height
|
||||||
|
@ -319,7 +320,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||||
|
|
||||||
shared.state.current_image = image
|
shared.state.current_image = image
|
||||||
|
|
||||||
if save_image_with_stored_embedding and os.path.exists(last_saved_file):
|
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
|
||||||
|
|
||||||
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
|
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
|
||||||
|
|
||||||
|
@ -328,15 +329,22 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||||
info.add_text("sd-ti-embedding", embedding_to_b64(data))
|
info.add_text("sd-ti-embedding", embedding_to_b64(data))
|
||||||
|
|
||||||
title = "<{}>".format(data.get('name', '???'))
|
title = "<{}>".format(data.get('name', '???'))
|
||||||
|
|
||||||
|
try:
|
||||||
|
vectorSize = list(data['string_to_param'].values())[0].shape[0]
|
||||||
|
except Exception as e:
|
||||||
|
vectorSize = '?'
|
||||||
|
|
||||||
checkpoint = sd_models.select_checkpoint()
|
checkpoint = sd_models.select_checkpoint()
|
||||||
footer_left = checkpoint.model_name
|
footer_left = checkpoint.model_name
|
||||||
footer_mid = '[{}]'.format(checkpoint.hash)
|
footer_mid = '[{}]'.format(checkpoint.hash)
|
||||||
footer_right = '{}'.format(embedding.step)
|
footer_right = '{}v {}s'.format(vectorSize, embedding.step)
|
||||||
|
|
||||||
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
|
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
|
||||||
captioned_image = insert_image_data_embed(captioned_image, data)
|
captioned_image = insert_image_data_embed(captioned_image, data)
|
||||||
|
|
||||||
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
|
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
|
||||||
|
embedding_yet_to_be_embedded = False
|
||||||
|
|
||||||
image.save(last_saved_image)
|
image.save(last_saved_image)
|
||||||
|
|
||||||
|
@ -348,7 +356,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||||
<p>
|
<p>
|
||||||
Loss: {losses.mean():.7f}<br/>
|
Loss: {losses.mean():.7f}<br/>
|
||||||
Step: {embedding.step}<br/>
|
Step: {embedding.step}<br/>
|
||||||
Last prompt: {html.escape(entry.cond_text)}<br/>
|
Last prompt: {html.escape(entries[0].cond_text)}<br/>
|
||||||
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
</p>
|
</p>
|
||||||
|
|
101
modules/ui.py
101
modules/ui.py
|
@ -7,6 +7,7 @@ import mimetypes
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
|
import tempfile
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
import platform
|
import platform
|
||||||
|
@ -80,6 +81,8 @@ art_symbol = '\U0001f3a8' # 🎨
|
||||||
paste_symbol = '\u2199\ufe0f' # ↙
|
paste_symbol = '\u2199\ufe0f' # ↙
|
||||||
folder_symbol = '\U0001f4c2' # 📂
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
refresh_symbol = '\U0001f504' # 🔄
|
refresh_symbol = '\U0001f504' # 🔄
|
||||||
|
save_style_symbol = '\U0001f4be' # 💾
|
||||||
|
apply_style_symbol = '\U0001f4cb' # 📋
|
||||||
|
|
||||||
|
|
||||||
def plaintext_to_html(text):
|
def plaintext_to_html(text):
|
||||||
|
@ -88,6 +91,14 @@ def plaintext_to_html(text):
|
||||||
|
|
||||||
|
|
||||||
def image_from_url_text(filedata):
|
def image_from_url_text(filedata):
|
||||||
|
if type(filedata) == dict and filedata["is_file"]:
|
||||||
|
filename = filedata["name"]
|
||||||
|
tempdir = os.path.normpath(tempfile.gettempdir())
|
||||||
|
normfn = os.path.normpath(filename)
|
||||||
|
assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
|
||||||
|
|
||||||
|
return Image.open(filename)
|
||||||
|
|
||||||
if type(filedata) == list:
|
if type(filedata) == list:
|
||||||
if len(filedata) == 0:
|
if len(filedata) == 0:
|
||||||
return None
|
return None
|
||||||
|
@ -143,10 +154,7 @@ def save_files(js_data, images, do_make_zip, index):
|
||||||
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
|
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
|
||||||
|
|
||||||
for image_index, filedata in enumerate(images, start_index):
|
for image_index, filedata in enumerate(images, start_index):
|
||||||
if filedata.startswith("data:image/png;base64,"):
|
image = image_from_url_text(filedata)
|
||||||
filedata = filedata[len("data:image/png;base64,"):]
|
|
||||||
|
|
||||||
image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8'))))
|
|
||||||
|
|
||||||
is_grid = image_index < p.index_of_first_image
|
is_grid = image_index < p.index_of_first_image
|
||||||
i = 0 if is_grid else (image_index - p.index_of_first_image)
|
i = 0 if is_grid else (image_index - p.index_of_first_image)
|
||||||
|
@ -176,6 +184,23 @@ def save_files(js_data, images, do_make_zip, index):
|
||||||
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
|
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
|
||||||
|
|
||||||
|
|
||||||
|
def save_pil_to_file(pil_image, dir=None):
|
||||||
|
use_metadata = False
|
||||||
|
metadata = PngImagePlugin.PngInfo()
|
||||||
|
for key, value in pil_image.info.items():
|
||||||
|
if isinstance(key, str) and isinstance(value, str):
|
||||||
|
metadata.add_text(key, value)
|
||||||
|
use_metadata = True
|
||||||
|
|
||||||
|
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
|
||||||
|
pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
|
||||||
|
return file_obj
|
||||||
|
|
||||||
|
|
||||||
|
# override save to file function so that it also writes PNG info
|
||||||
|
gr.processing_utils.save_pil_to_file = save_pil_to_file
|
||||||
|
|
||||||
|
|
||||||
def wrap_gradio_call(func, extra_outputs=None):
|
def wrap_gradio_call(func, extra_outputs=None):
|
||||||
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
||||||
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
|
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
|
||||||
|
@ -304,7 +329,7 @@ def visit(x, func, path=""):
|
||||||
|
|
||||||
def add_style(name: str, prompt: str, negative_prompt: str):
|
def add_style(name: str, prompt: str, negative_prompt: str):
|
||||||
if name is None:
|
if name is None:
|
||||||
return [gr_show(), gr_show()]
|
return [gr_show() for x in range(4)]
|
||||||
|
|
||||||
style = modules.styles.PromptStyle(name, prompt, negative_prompt)
|
style = modules.styles.PromptStyle(name, prompt, negative_prompt)
|
||||||
shared.prompt_styles.styles[style.name] = style
|
shared.prompt_styles.styles[style.name] = style
|
||||||
|
@ -429,29 +454,38 @@ def create_toprow(is_img2img):
|
||||||
id_part = "img2img" if is_img2img else "txt2img"
|
id_part = "img2img" if is_img2img else "txt2img"
|
||||||
|
|
||||||
with gr.Row(elem_id="toprow"):
|
with gr.Row(elem_id="toprow"):
|
||||||
with gr.Column(scale=4):
|
with gr.Column(scale=6):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=80):
|
with gr.Column(scale=80):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2)
|
prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
|
||||||
|
placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=80):
|
||||||
|
with gr.Row():
|
||||||
|
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
|
||||||
|
placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
|
||||||
|
)
|
||||||
|
|
||||||
with gr.Column(scale=1, elem_id="roll_col"):
|
with gr.Column(scale=1, elem_id="roll_col"):
|
||||||
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
||||||
paste = gr.Button(value=paste_symbol, elem_id="paste")
|
paste = gr.Button(value=paste_symbol, elem_id="paste")
|
||||||
|
save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
|
||||||
|
prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
|
||||||
|
|
||||||
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
|
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
|
||||||
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
||||||
|
|
||||||
with gr.Column(scale=10, elem_id="style_pos_col"):
|
button_interrogate = None
|
||||||
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
|
button_deepbooru = None
|
||||||
|
if is_img2img:
|
||||||
|
with gr.Column(scale=1, elem_id="interrogate_col"):
|
||||||
|
button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
|
||||||
|
|
||||||
with gr.Row():
|
if cmd_opts.deepdanbooru:
|
||||||
with gr.Column(scale=8):
|
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
|
||||||
with gr.Row():
|
|
||||||
negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
|
|
||||||
with gr.Column(scale=1, elem_id="roll_col"):
|
|
||||||
sh = gr.Button(elem_id="sh", visible=True)
|
|
||||||
|
|
||||||
with gr.Column(scale=1, elem_id="style_neg_col"):
|
|
||||||
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
|
|
||||||
|
|
||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
@ -471,20 +505,14 @@ def create_toprow(is_img2img):
|
||||||
outputs=[],
|
outputs=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Row(scale=1):
|
with gr.Row():
|
||||||
if is_img2img:
|
with gr.Column(scale=1, elem_id="style_pos_col"):
|
||||||
interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
|
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
|
||||||
if cmd_opts.deepdanbooru:
|
|
||||||
deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
|
|
||||||
else:
|
|
||||||
deepbooru = None
|
|
||||||
else:
|
|
||||||
interrogate = None
|
|
||||||
deepbooru = None
|
|
||||||
prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
|
|
||||||
save_style = gr.Button('Create style', elem_id="style_create")
|
|
||||||
|
|
||||||
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
|
with gr.Column(scale=1, elem_id="style_neg_col"):
|
||||||
|
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
|
||||||
|
|
||||||
|
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
|
||||||
|
|
||||||
|
|
||||||
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
|
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
|
||||||
|
@ -588,7 +616,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
|
txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
|
||||||
txt2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='txt2img_gallery').style(grid=4)
|
txt2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='txt2img_gallery').style(grid=4)
|
||||||
|
|
||||||
with gr.Group():
|
with gr.Column():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
save = gr.Button('Save')
|
save = gr.Button('Save')
|
||||||
send_to_img2img = gr.Button('Send to img2img')
|
send_to_img2img = gr.Button('Send to img2img')
|
||||||
|
@ -744,10 +772,10 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
|
|
||||||
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
|
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
|
||||||
with gr.TabItem('img2img', id='img2img'):
|
with gr.TabItem('img2img', id='img2img'):
|
||||||
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool)
|
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480)
|
||||||
|
|
||||||
with gr.TabItem('Inpaint', id='inpaint'):
|
with gr.TabItem('Inpaint', id='inpaint'):
|
||||||
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA")
|
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
|
||||||
|
|
||||||
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
|
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
|
||||||
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
|
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
|
||||||
|
@ -803,7 +831,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
|
img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
|
||||||
img2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='img2img_gallery').style(grid=4)
|
img2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='img2img_gallery').style(grid=4)
|
||||||
|
|
||||||
with gr.Group():
|
with gr.Column():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
save = gr.Button('Save')
|
save = gr.Button('Save')
|
||||||
img2img_send_to_img2img = gr.Button('Send to img2img')
|
img2img_send_to_img2img = gr.Button('Send to img2img')
|
||||||
|
@ -1166,6 +1194,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||||
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
|
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
|
||||||
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
|
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
|
||||||
|
batch_size = gr.Number(label='Batch size', value=1, precision=0)
|
||||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
||||||
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
|
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
|
||||||
|
@ -1244,6 +1273,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
inputs=[
|
inputs=[
|
||||||
train_embedding_name,
|
train_embedding_name,
|
||||||
learn_rate,
|
learn_rate,
|
||||||
|
batch_size,
|
||||||
dataset_directory,
|
dataset_directory,
|
||||||
log_directory,
|
log_directory,
|
||||||
training_width,
|
training_width,
|
||||||
|
@ -1268,6 +1298,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
inputs=[
|
inputs=[
|
||||||
train_hypernetwork_name,
|
train_hypernetwork_name,
|
||||||
learn_rate,
|
learn_rate,
|
||||||
|
batch_size,
|
||||||
dataset_directory,
|
dataset_directory,
|
||||||
log_directory,
|
log_directory,
|
||||||
steps,
|
steps,
|
||||||
|
|
|
@ -4,7 +4,7 @@ fairscale==0.4.4
|
||||||
fonts
|
fonts
|
||||||
font-roboto
|
font-roboto
|
||||||
gfpgan
|
gfpgan
|
||||||
gradio==3.4.1
|
gradio==3.5
|
||||||
invisible-watermark
|
invisible-watermark
|
||||||
numpy
|
numpy
|
||||||
omegaconf
|
omegaconf
|
||||||
|
|
|
@ -2,7 +2,7 @@ transformers==4.19.2
|
||||||
diffusers==0.3.0
|
diffusers==0.3.0
|
||||||
basicsr==1.4.2
|
basicsr==1.4.2
|
||||||
gfpgan==1.3.8
|
gfpgan==1.3.8
|
||||||
gradio==3.4.1
|
gradio==3.5
|
||||||
numpy==1.23.3
|
numpy==1.23.3
|
||||||
Pillow==9.2.0
|
Pillow==9.2.0
|
||||||
realesrgan==0.3.0
|
realesrgan==0.3.0
|
||||||
|
|
|
@ -50,9 +50,9 @@ document.addEventListener("DOMContentLoaded", function() {
|
||||||
document.addEventListener('keydown', function(e) {
|
document.addEventListener('keydown', function(e) {
|
||||||
var handled = false;
|
var handled = false;
|
||||||
if (e.key !== undefined) {
|
if (e.key !== undefined) {
|
||||||
if((e.key == "Enter" && (e.metaKey || e.ctrlKey))) handled = true;
|
if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
|
||||||
} else if (e.keyCode !== undefined) {
|
} else if (e.keyCode !== undefined) {
|
||||||
if((e.keyCode == 13 && (e.metaKey || e.ctrlKey))) handled = true;
|
if((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
|
||||||
}
|
}
|
||||||
if (handled) {
|
if (handled) {
|
||||||
button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
|
button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
|
import copy
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
import shlex
|
||||||
|
|
||||||
import modules.scripts as scripts
|
import modules.scripts as scripts
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
@ -10,6 +12,75 @@ from modules.processing import Processed, process_images
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
|
|
||||||
|
|
||||||
|
def process_string_tag(tag):
    """Converter for string-valued prompt tags: the raw text is returned unchanged."""
    return tag
|
||||||
|
|
||||||
|
|
||||||
|
def process_int_tag(tag):
    """Converter for integer-valued prompt tags.

    Returns ``int(tag)``; a ``ValueError`` from ``int`` propagates when the
    text is not a valid integer literal.
    """
    parsed = int(tag)
    return parsed
|
||||||
|
|
||||||
|
|
||||||
|
def process_float_tag(tag):
    """Converter for float-valued prompt tags.

    Returns ``float(tag)``; a ``ValueError`` from ``float`` propagates when
    the text is not a valid number.
    """
    parsed = float(tag)
    return parsed
|
||||||
|
|
||||||
|
|
||||||
|
def process_boolean_tag(tag):
    """Converter for boolean-valued prompt tags.

    Returns True when the text is "true" and False for anything else.
    The comparison ignores letter case and surrounding whitespace so that
    spellings such as "True" or "TRUE" are not silently read as False.
    """
    return tag.strip().lower() == "true"
|
||||||
|
|
||||||
|
|
||||||
|
# Maps each option name accepted on a prompt line (written as "--<name>") to
# the converter that turns its textual value into the value stored for it.
prompt_tags = {
    # No converter registered: cmdargs() rejects options whose entry is falsy.
    "sd_model": None,

    # Text values, used as written.
    "outpath_samples": process_string_tag,
    "outpath_grids": process_string_tag,
    "prompt_for_display": process_string_tag,
    "prompt": process_string_tag,
    "negative_prompt": process_string_tag,
    "styles": process_string_tag,

    # Numeric values.
    "seed": process_int_tag,
    "subseed_strength": process_float_tag,
    "subseed": process_int_tag,
    "seed_resize_from_h": process_int_tag,
    "seed_resize_from_w": process_int_tag,
    "sampler_index": process_int_tag,
    "batch_size": process_int_tag,
    "n_iter": process_int_tag,
    "steps": process_int_tag,
    "cfg_scale": process_float_tag,
    "width": process_int_tag,
    "height": process_int_tag,

    # Flags.
    "restore_faces": process_boolean_tag,
    "tiling": process_boolean_tag,
    "do_not_save_samples": process_boolean_tag,
    "do_not_save_grid": process_boolean_tag,
}
|
||||||
|
|
||||||
|
|
||||||
|
def cmdargs(line):
    """Parse one prompt line written as ``--tag value`` pairs.

    The line is tokenised with ``shlex.split`` (so values may be quoted) and
    every option's value is converted with the function registered for it in
    ``prompt_tags``.

    Args:
        line: Text such as ``--prompt "a cat" --steps 20``.

    Returns:
        dict mapping tag names (without the leading ``--``) to converted
        values. When a tag is repeated, the last occurrence wins.

    Raises:
        AssertionError: a token in option position does not start with
            ``--``, the option has no usable converter in ``prompt_tags``,
            or the option is the last token and has no value.
        ValueError: from ``shlex.split`` or from a converter on malformed
            text.
    """
    args = shlex.split(line)
    res = {}

    # Tokens are consumed two at a time: an option followed by its value.
    for pos in range(0, len(args), 2):
        arg = args[pos]

        # Raised explicitly instead of via ``assert`` so the checks still run
        # under ``python -O``; the exception type and messages are unchanged.
        if not arg.startswith("--"):
            raise AssertionError(f'must start with "--": {arg}')
        tag = arg[2:]

        func = prompt_tags.get(tag, None)
        if not func:
            raise AssertionError(f'unknown commandline option: {arg}')

        if pos + 1 >= len(args):
            raise AssertionError(f'missing argument for command line option {arg}')

        res[tag] = func(args[pos + 1])

    return res
|
||||||
|
|
||||||
|
|
||||||
class Script(scripts.Script):
|
class Script(scripts.Script):
|
||||||
def title(self):
    """Return the title string of this script."""
    script_title = "Prompts from file or textbox"
    return script_title
|
||||||
|
@ -32,26 +103,48 @@ class Script(scripts.Script):
|
||||||
return [ gr.Checkbox.update(visible = True), gr.File.update(visible = not checkbox_txt), gr.TextArea.update(visible = checkbox_txt) ]
|
return [ gr.Checkbox.update(visible = True), gr.File.update(visible = not checkbox_txt), gr.TextArea.update(visible = checkbox_txt) ]
|
||||||
|
|
||||||
def run(self, p, checkbox_txt, data: bytes, prompt_txt: str):
    """Run one processing job per non-empty input line and collect the images.

    Args:
        p: Processing object used as the template for every job. It is
            shallow-copied per job; only ``do_not_save_grid`` is changed on
            ``p`` itself.
        checkbox_txt: When truthy the lines come from ``prompt_txt``,
            otherwise from ``data``.
        data: Raw file contents, decoded as UTF-8 with undecodable bytes
            ignored.
        prompt_txt: Text with one prompt per line.

    A line containing ``--`` is parsed with ``cmdargs`` into attribute
    overrides; if parsing fails the error is reported on stderr and the whole
    line is used as a plain prompt. Any other line is used as the prompt.

    Returns:
        ``Processed(p, images, p.seed, "")`` where ``images`` is the
        concatenation of every job's images.
    """
    if checkbox_txt:
        raw_lines = prompt_txt.splitlines()
    else:
        raw_lines = data.decode('utf8', errors='ignore').split("\n")

    lines = [x.strip() for x in raw_lines]
    lines = [x for x in lines if len(x) > 0]

    # Set before the per-job copies are made so every copy inherits it.
    p.do_not_save_grid = True

    job_count = 0
    jobs = []

    for line in lines:
        if "--" in line:
            try:
                args = cmdargs(line)
            except Exception:
                print(f"Error parsing line {line} as commandline:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                # Best effort: fall back to treating the line as a prompt.
                args = {"prompt": line}
        else:
            args = {"prompt": line}

        # A line without its own n_iter runs with the template's value, so
        # that is the number of iterations it contributes to the total.
        job_count += args.get("n_iter", p.n_iter)

        jobs.append(args)

    print(f"Will process {len(lines)} lines in {job_count} jobs.")
    state.job_count = job_count

    images = []
    for args in jobs:
        state.job = f"{state.job_no + 1} out of {state.job_count}"

        # Apply the overrides to a copy so one line's settings do not leak
        # into the next job.
        copy_p = copy.copy(p)
        for k, v in args.items():
            setattr(copy_p, k, v)

        proc = process_images(copy_p)
        images += proc.images

    return Processed(p, images, p.seed, "")
|
||||||
|
|
|
@ -12,7 +12,7 @@ import gradio as gr
|
||||||
|
|
||||||
from modules import images
|
from modules import images
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.processing import process_images, Processed, get_correct_sampler
|
from modules.processing import process_images, Processed, get_correct_sampler, StableDiffusionProcessingTxt2Img
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import modules.sd_samplers
|
import modules.sd_samplers
|
||||||
|
@ -354,6 +354,9 @@ class Script(scripts.Script):
|
||||||
else:
|
else:
|
||||||
total_steps = p.steps * len(xs) * len(ys)
|
total_steps = p.steps * len(xs) * len(ys)
|
||||||
|
|
||||||
|
if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
|
||||||
|
total_steps *= 2
|
||||||
|
|
||||||
print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})")
|
print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})")
|
||||||
shared.total_tqdm.updateTotal(total_steps * p.n_iter)
|
shared.total_tqdm.updateTotal(total_steps * p.n_iter)
|
||||||
|
|
||||||
|
|
41
style.css
41
style.css
|
@ -115,7 +115,7 @@
|
||||||
padding: 0.4em 0;
|
padding: 0.4em 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
#roll, #paste{
|
#roll, #paste, #style_create, #style_apply{
|
||||||
min-width: 2em;
|
min-width: 2em;
|
||||||
min-height: 2em;
|
min-height: 2em;
|
||||||
max-width: 2em;
|
max-width: 2em;
|
||||||
|
@ -126,14 +126,14 @@
|
||||||
margin: 0.1em 0;
|
margin: 0.1em 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
#style_apply, #style_create, #interrogate{
|
#interrogate_col{
|
||||||
margin: 0.75em 0.25em 0.25em 0.25em;
|
min-width: 0 !important;
|
||||||
min-width: 5em;
|
max-width: 8em !important;
|
||||||
}
|
}
|
||||||
|
#interrogate, #deepbooru{
|
||||||
#style_apply, #style_create, #deepbooru{
|
margin: 0em 0.25em 0.9em 0.25em;
|
||||||
margin: 0.75em 0.25em 0.25em 0.25em;
|
min-width: 8em;
|
||||||
min-width: 5em;
|
max-width: 8em;
|
||||||
}
|
}
|
||||||
|
|
||||||
#style_pos_col, #style_neg_col{
|
#style_pos_col, #style_neg_col{
|
||||||
|
@ -167,10 +167,6 @@ button{
|
||||||
align-self: stretch !important;
|
align-self: stretch !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
#img2maskimg .h-60{
|
|
||||||
height: 30rem;
|
|
||||||
}
|
|
||||||
|
|
||||||
.overflow-hidden, .gr-panel{
|
.overflow-hidden, .gr-panel{
|
||||||
overflow: visible !important;
|
overflow: visible !important;
|
||||||
}
|
}
|
||||||
|
@ -241,13 +237,6 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s
|
||||||
margin: 0;
|
margin: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
.gr-panel div.flex-col div.justify-between div{
|
|
||||||
position: absolute;
|
|
||||||
top: -0.1em;
|
|
||||||
right: 1em;
|
|
||||||
padding: 0 0.5em;
|
|
||||||
}
|
|
||||||
|
|
||||||
#settings .gr-panel div.flex-col div.justify-between div{
|
#settings .gr-panel div.flex-col div.justify-between div{
|
||||||
position: relative;
|
position: relative;
|
||||||
z-index: 200;
|
z-index: 200;
|
||||||
|
@ -320,6 +309,8 @@ input[type="range"]{
|
||||||
height: 100%;
|
height: 100%;
|
||||||
overflow: auto;
|
overflow: auto;
|
||||||
background-color: rgba(20, 20, 20, 0.95);
|
background-color: rgba(20, 20, 20, 0.95);
|
||||||
|
user-select: none;
|
||||||
|
-webkit-user-select: none;
|
||||||
}
|
}
|
||||||
|
|
||||||
.modalControls {
|
.modalControls {
|
||||||
|
@ -443,10 +434,6 @@ input[type="range"]{
|
||||||
--tw-bg-opacity: 0 !important;
|
--tw-bg-opacity: 0 !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
#img2img_image div.h-60{
|
|
||||||
height: 480px;
|
|
||||||
}
|
|
||||||
|
|
||||||
#context-menu{
|
#context-menu{
|
||||||
z-index:9999;
|
z-index:9999;
|
||||||
position:absolute;
|
position:absolute;
|
||||||
|
@ -521,3 +508,11 @@ canvas[key="mask"] {
|
||||||
.row.gr-compact{
|
.row.gr-compact{
|
||||||
overflow: visible;
|
overflow: visible;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Force a fixed 480px height on #img2img_image and #img2maskimg and on the
   nested .h-60 > div > img chain inside each of them. */
#img2img_image, #img2img_image > .h-60, #img2img_image > .h-60 > div, #img2img_image > .h-60 > div > img,
#img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h-60 > div > img
{
    height: 480px !important;
    max-height: 480px !important;
    min-height: 480px !important;
}
|
||||||
|
|
4
webui.sh
4
webui.sh
|
@ -82,8 +82,8 @@ then
|
||||||
clone_dir="${PWD##*/}"
|
clone_dir="${PWD##*/}"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Check prequisites
|
# Check prerequisites
|
||||||
for preq in git python3
|
for preq in "${GIT}" "${python_cmd}"
|
||||||
do
|
do
|
||||||
if ! hash "${preq}" &>/dev/null
|
if ! hash "${preq}" &>/dev/null
|
||||||
then
|
then
|
||||||
|
|
Loading…
Reference in a new issue