Merge branch 'master' into hot-reload-javascript

This commit is contained in:
AUTOMATIC1111 2022-10-19 09:43:49 +03:00 committed by GitHub
commit 05315d8a23
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
36 changed files with 858 additions and 232 deletions

View file

@ -22,6 +22,12 @@ jobs:
uses: actions/setup-python@v3 uses: actions/setup-python@v3
with: with:
python-version: 3.10.6 python-version: 3.10.6
- uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install PyLint - name: Install PyLint
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip

View file

@ -523,7 +523,6 @@ Affandi,0.7170285,nudity
Diane Arbus,0.655138,digipa-high-impact Diane Arbus,0.655138,digipa-high-impact
Joseph Ducreux,0.65247905,digipa-high-impact Joseph Ducreux,0.65247905,digipa-high-impact
Berthe Morisot,0.7165984,fineart Berthe Morisot,0.7165984,fineart
Hilma AF Klint,0.71643853,scribbles
Hilma af Klint,0.71643853,scribbles Hilma af Klint,0.71643853,scribbles
Filippino Lippi,0.7163017,fineart Filippino Lippi,0.7163017,fineart
Leonid Afremov,0.7163005,fineart Leonid Afremov,0.7163005,fineart
@ -738,14 +737,12 @@ Abraham Mignon,0.60605425,fineart
Albert Bloch,0.69573116,nudity Albert Bloch,0.69573116,nudity
Charles Dana Gibson,0.67155975,fineart Charles Dana Gibson,0.67155975,fineart
Alexandre-Évariste Fragonard,0.6507174,fineart Alexandre-Évariste Fragonard,0.6507174,fineart
Alexandre-Évariste Fragonard,0.6507174,fineart
Ernst Fuchs,0.6953538,nudity Ernst Fuchs,0.6953538,nudity
Alfredo Jaar,0.6952965,digipa-high-impact Alfredo Jaar,0.6952965,digipa-high-impact
Judy Chicago,0.6952246,weird Judy Chicago,0.6952246,weird
Frans van Mieris the Younger,0.6951849,fineart Frans van Mieris the Younger,0.6951849,fineart
Aertgen van Leyden,0.6951305,fineart Aertgen van Leyden,0.6951305,fineart
Emily Carr,0.69512105,fineart Emily Carr,0.69512105,fineart
Frances Macdonald,0.6950408,scribbles
Frances MacDonald,0.6950408,scribbles Frances MacDonald,0.6950408,scribbles
Hannah Höch,0.69495845,scribbles Hannah Höch,0.69495845,scribbles
Gillis Rombouts,0.58770025,fineart Gillis Rombouts,0.58770025,fineart
@ -895,7 +892,6 @@ Richard McGuire,0.6820089,scribbles
Anni Albers,0.65708244,digipa-high-impact Anni Albers,0.65708244,digipa-high-impact
Aleksey Savrasov,0.65207493,fineart Aleksey Savrasov,0.65207493,fineart
Wayne Barlowe,0.6537874,fineart Wayne Barlowe,0.6537874,fineart
Giorgio De Chirico,0.6815907,fineart
Giorgio de Chirico,0.6815907,fineart Giorgio de Chirico,0.6815907,fineart
Ernest Procter,0.6815795,fineart Ernest Procter,0.6815795,fineart
Adriaen Brouwer,0.6815058,fineart Adriaen Brouwer,0.6815058,fineart
@ -1241,7 +1237,6 @@ Betty Churcher,0.65387225,fineart
Claes Corneliszoon Moeyaert,0.65386075,fineart Claes Corneliszoon Moeyaert,0.65386075,fineart
David Bomberg,0.6537477,fineart David Bomberg,0.6537477,fineart
Abraham Bosschaert,0.6535562,fineart Abraham Bosschaert,0.6535562,fineart
Giuseppe De Nittis,0.65354455,fineart
Giuseppe de Nittis,0.65354455,fineart Giuseppe de Nittis,0.65354455,fineart
John La Farge,0.65342575,fineart John La Farge,0.65342575,fineart
Frits Thaulow,0.65341854,fineart Frits Thaulow,0.65341854,fineart
@ -1522,7 +1517,6 @@ Gertrude Harvey,0.5903887,fineart
Grant Wood,0.6266253,fineart Grant Wood,0.6266253,fineart
Fyodor Vasilyev,0.5234919,digipa-med-impact Fyodor Vasilyev,0.5234919,digipa-med-impact
Cagnaccio di San Pietro,0.6261671,fineart Cagnaccio di San Pietro,0.6261671,fineart
Cagnaccio Di San Pietro,0.6261671,fineart
Doris Boulton-Maude,0.62593174,fineart Doris Boulton-Maude,0.62593174,fineart
Adolf Hirémy-Hirschl,0.5946784,fineart Adolf Hirémy-Hirschl,0.5946784,fineart
Harold von Schmidt,0.6256755,fineart Harold von Schmidt,0.6256755,fineart
@ -2411,7 +2405,6 @@ Hermann Feierabend,0.5346168,digipa-high-impact
Antonio Donghi,0.4610982,digipa-low-impact Antonio Donghi,0.4610982,digipa-low-impact
Adonna Khare,0.4858036,digipa-med-impact Adonna Khare,0.4858036,digipa-med-impact
James Stokoe,0.5015107,digipa-med-impact James Stokoe,0.5015107,digipa-med-impact
Art & Language,0.5341332,digipa-high-impact
Agustín Fernández,0.53403986,fineart Agustín Fernández,0.53403986,fineart
Germán Londoño,0.5338712,fineart Germán Londoño,0.5338712,fineart
Emmanuelle Moureaux,0.5335641,digipa-high-impact Emmanuelle Moureaux,0.5335641,digipa-high-impact

1 artist score category
523 Diane Arbus 0.655138 digipa-high-impact
524 Joseph Ducreux 0.65247905 digipa-high-impact
525 Berthe Morisot 0.7165984 fineart
Hilma AF Klint 0.71643853 scribbles
526 Hilma af Klint 0.71643853 scribbles
527 Filippino Lippi 0.7163017 fineart
528 Leonid Afremov 0.7163005 fineart
737 Albert Bloch 0.69573116 nudity
738 Charles Dana Gibson 0.67155975 fineart
739 Alexandre-Évariste Fragonard 0.6507174 fineart
Alexandre-Évariste Fragonard 0.6507174 fineart
740 Ernst Fuchs 0.6953538 nudity
741 Alfredo Jaar 0.6952965 digipa-high-impact
742 Judy Chicago 0.6952246 weird
743 Frans van Mieris the Younger 0.6951849 fineart
744 Aertgen van Leyden 0.6951305 fineart
745 Emily Carr 0.69512105 fineart
Frances Macdonald 0.6950408 scribbles
746 Frances MacDonald 0.6950408 scribbles
747 Hannah Höch 0.69495845 scribbles
748 Gillis Rombouts 0.58770025 fineart
892 Anni Albers 0.65708244 digipa-high-impact
893 Aleksey Savrasov 0.65207493 fineart
894 Wayne Barlowe 0.6537874 fineart
Giorgio De Chirico 0.6815907 fineart
895 Giorgio de Chirico 0.6815907 fineart
896 Ernest Procter 0.6815795 fineart
897 Adriaen Brouwer 0.6815058 fineart
1237 Claes Corneliszoon Moeyaert 0.65386075 fineart
1238 David Bomberg 0.6537477 fineart
1239 Abraham Bosschaert 0.6535562 fineart
Giuseppe De Nittis 0.65354455 fineart
1240 Giuseppe de Nittis 0.65354455 fineart
1241 John La Farge 0.65342575 fineart
1242 Frits Thaulow 0.65341854 fineart
1517 Grant Wood 0.6266253 fineart
1518 Fyodor Vasilyev 0.5234919 digipa-med-impact
1519 Cagnaccio di San Pietro 0.6261671 fineart
Cagnaccio Di San Pietro 0.6261671 fineart
1520 Doris Boulton-Maude 0.62593174 fineart
1521 Adolf Hirémy-Hirschl 0.5946784 fineart
1522 Harold von Schmidt 0.6256755 fineart
2405 Antonio Donghi 0.4610982 digipa-low-impact
2406 Adonna Khare 0.4858036 digipa-med-impact
2407 James Stokoe 0.5015107 digipa-med-impact
Art & Language 0.5341332 digipa-high-impact
2408 Agustín Fernández 0.53403986 fineart
2409 Germán Londoño 0.5338712 fineart
2410 Emmanuelle Moureaux 0.5335641 digipa-high-impact

View file

@ -9,9 +9,38 @@ addEventListener('keydown', (event) => {
let minus = "ArrowDown" let minus = "ArrowDown"
if (event.key != plus && event.key != minus) return; if (event.key != plus && event.key != minus) return;
selectionStart = target.selectionStart; let selectionStart = target.selectionStart;
selectionEnd = target.selectionEnd; let selectionEnd = target.selectionEnd;
if(selectionStart == selectionEnd) return; // If the user hasn't selected anything, let's select their current parenthesis block
if (selectionStart === selectionEnd) {
// Find opening parenthesis around current cursor
const before = target.value.substring(0, selectionStart);
let beforeParen = before.lastIndexOf("(");
if (beforeParen == -1) return;
let beforeParenClose = before.lastIndexOf(")");
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
beforeParen = before.lastIndexOf("(", beforeParen - 1);
beforeParenClose = before.lastIndexOf(")", beforeParenClose - 1);
}
// Find closing parenthesis around current cursor
const after = target.value.substring(selectionStart);
let afterParen = after.indexOf(")");
if (afterParen == -1) return;
let afterParenOpen = after.indexOf("(");
while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
afterParen = after.indexOf(")", afterParen + 1);
afterParenOpen = after.indexOf("(", afterParenOpen + 1);
}
if (beforeParen === -1 || afterParen === -1) return;
// Set the selection to the text between the parenthesis
const parenContent = target.value.substring(beforeParen + 1, selectionStart + afterParen);
const lastColon = parenContent.lastIndexOf(":");
selectionStart = beforeParen + 1;
selectionEnd = selectionStart + lastColon;
target.setSelectionRange(selectionStart, selectionEnd);
}
event.preventDefault(); event.preventDefault();

View file

@ -31,8 +31,8 @@ function imageMaskResize() {
wrapper.style.width = `${wW}px`; wrapper.style.width = `${wW}px`;
wrapper.style.height = `${wH}px`; wrapper.style.height = `${wH}px`;
wrapper.style.left = `${(w-wW)/2}px`; wrapper.style.left = `0px`;
wrapper.style.top = `${(h-wH)/2}px`; wrapper.style.top = `0px`;
canvases.forEach( c => { canvases.forEach( c => {
c.style.width = c.style.height = ''; c.style.width = c.style.height = '';

View file

@ -31,7 +31,7 @@ function updateOnBackgroundChange() {
} }
}) })
if (modalImage.src != currentButton.children[0].src) { if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
modalImage.src = currentButton.children[0].src; modalImage.src = currentButton.children[0].src;
if (modalImage.style.display === 'none') { if (modalImage.style.display === 'none') {
modal.style.setProperty('background-image', `url(${modalImage.src})`) modal.style.setProperty('background-image', `url(${modalImage.src})`)
@ -116,6 +116,7 @@ function showGalleryImage() {
e.dataset.modded = true; e.dataset.modded = true;
if(e && e.parentElement.tagName == 'DIV'){ if(e && e.parentElement.tagName == 'DIV'){
e.style.cursor='pointer' e.style.cursor='pointer'
e.style.userSelect='none'
e.addEventListener('click', function (evt) { e.addEventListener('click', function (evt) {
if(!opts.js_modal_lightbox) return; if(!opts.js_modal_lightbox) return;
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed) modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)

146
javascript/localization.js Normal file
View file

@ -0,0 +1,146 @@
// localization = {} -- the dict with translations is created by the backend
ignore_ids_for_localization={
setting_sd_hypernetwork: 'OPTION',
setting_sd_model_checkpoint: 'OPTION',
setting_realesrgan_enabled_models: 'OPTION',
modelmerger_primary_model_name: 'OPTION',
modelmerger_secondary_model_name: 'OPTION',
modelmerger_tertiary_model_name: 'OPTION',
train_embedding: 'OPTION',
train_hypernetwork: 'OPTION',
txt2img_style_index: 'OPTION',
txt2img_style2_index: 'OPTION',
img2img_style_index: 'OPTION',
img2img_style2_index: 'OPTION',
setting_random_artist_categories: 'SPAN',
setting_face_restoration_model: 'SPAN',
setting_realesrgan_enabled_models: 'SPAN',
extras_upscaler_1: 'SPAN',
extras_upscaler_2: 'SPAN',
}
re_num = /^[\.\d]+$/
re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
original_lines = {}
translated_lines = {}
function textNodesUnder(el){
var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
while(n=walk.nextNode()) a.push(n);
return a;
}
function canBeTranslated(node, text){
if(! text) return false;
if(! node.parentElement) return false;
parentType = node.parentElement.nodeName
if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
if (parentType=='OPTION' || parentType=='SPAN'){
pnode = node
for(var level=0; level<4; level++){
pnode = pnode.parentElement
if(! pnode) break;
if(ignore_ids_for_localization[pnode.id] == parentType) return false;
}
}
if(re_num.test(text)) return false;
if(re_emoji.test(text)) return false;
return true
}
function getTranslation(text){
if(! text) return undefined
if(translated_lines[text] === undefined){
original_lines[text] = 1
}
tl = localization[text]
if(tl !== undefined){
translated_lines[tl] = 1
}
return tl
}
function processTextNode(node){
text = node.textContent.trim()
if(! canBeTranslated(node, text)) return
tl = getTranslation(text)
if(tl !== undefined){
node.textContent = tl
}
}
function processNode(node){
if(node.nodeType == 3){
processTextNode(node)
return
}
if(node.title){
tl = getTranslation(node.title)
if(tl !== undefined){
node.title = tl
}
}
if(node.placeholder){
tl = getTranslation(node.placeholder)
if(tl !== undefined){
node.placeholder = tl
}
}
textNodesUnder(node).forEach(function(node){
processTextNode(node)
})
}
function dumpTranslations(){
dumped = {}
Object.keys(original_lines).forEach(function(text){
if(dumped[text] !== undefined) return
dumped[text] = localization[text] || text
})
return dumped
}
onUiUpdate(function(m){
m.forEach(function(mutation){
mutation.addedNodes.forEach(function(node){
processNode(node)
})
});
})
document.addEventListener("DOMContentLoaded", function() {
processNode(gradioApp())
})
function download_localization() {
text = JSON.stringify(dumpTranslations(), null, 4)
var element = document.createElement('a');
element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
element.setAttribute('download', "localization.json");
element.style.display = 'none';
document.body.appendChild(element);
element.click();
document.body.removeChild(element);
}

View file

@ -72,9 +72,17 @@ function check_gallery(id_gallery){
let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item') let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2') let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) { if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
//automatically re-open previously selected index (if exists) // automatically re-open previously selected index (if exists)
activeElement = gradioApp().activeElement;
galleryButtons[prevSelectedIndex].click(); galleryButtons[prevSelectedIndex].click();
showGalleryImage(); showGalleryImage();
if(activeElement){
// i fought this for about an hour; i don't know why the focus is lost or why this helps recover it
// if somenoe has a better solution please by all means
setTimeout(function() { activeElement.focus() }, 1);
}
} }
}) })
galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false }) galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false })

View file

@ -1,5 +1,12 @@
// various functions for interation with ui.py not large enough to warrant putting them in separate files // various functions for interation with ui.py not large enough to warrant putting them in separate files
function set_theme(theme){
gradioURL = window.location.href
if (!gradioURL.includes('?__theme=')) {
window.location.replace(gradioURL + '?__theme=' + theme);
}
}
function selected_gallery_index(){ function selected_gallery_index(){
var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem .gallery-item') var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem .gallery-item')
var button = gradioApp().querySelector('[style="display: block;"].tabitem .gallery-item.\\!ring-2') var button = gradioApp().querySelector('[style="display: block;"].tabitem .gallery-item.\\!ring-2')

View file

@ -87,6 +87,23 @@ def git_clone(url, dir, name, commithash=None):
run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}") run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
def version_check(commit):
try:
import requests
commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
if commit != "<none>" and commits['commit']['sha'] != commit:
print("--------------------------------------------------------")
print("| You are not up to date with the most recent release. |")
print("| Consider running `git pull` to update. |")
print("--------------------------------------------------------")
elif commits['commit']['sha'] == commit:
print("You are up to date with the most recent release.")
else:
print("Not a git clone, can't perform version check.")
except Exception as e:
print("versipm check failed",e)
def prepare_enviroment(): def prepare_enviroment():
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
@ -94,6 +111,15 @@ def prepare_enviroment():
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379") gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1") clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
deepdanbooru_package = os.environ.get('DEEPDANBOORU_PACKAGE', "git+https://github.com/KichangKim/DeepDanbooru.git@edf73df4cdaeea2cf00e9ac08bd8a9026b7a7b26")
xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl')
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/CompVis/stable-diffusion.git")
taming_transformers_repo = os.environ.get('TAMING_REANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
codeformer_repo = os.environ.get('CODEFORMET_REPO', 'https://github.com/sczhou/CodeFormer.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc") stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6") taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
@ -101,13 +127,14 @@ def prepare_enviroment():
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
args = shlex.split(commandline_args) sys.argv += shlex.split(commandline_args)
args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test') sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
args, reinstall_xformers = extract_arg(args, '--reinstall-xformers') sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
xformers = '--xformers' in args sys.argv, update_check = extract_arg(sys.argv, '--update-check')
deepdanbooru = '--deepdanbooru' in args xformers = '--xformers' in sys.argv
ngrok = '--ngrok' in args deepdanbooru = '--deepdanbooru' in sys.argv
ngrok = '--ngrok' in sys.argv
try: try:
commit = run(f"{git} rev-parse HEAD").strip() commit = run(f"{git} rev-parse HEAD").strip()
@ -131,32 +158,33 @@ def prepare_enviroment():
if (not is_installed("xformers") or reinstall_xformers) and xformers and platform.python_version().startswith("3.10"): if (not is_installed("xformers") or reinstall_xformers) and xformers and platform.python_version().startswith("3.10"):
if platform.system() == "Windows": if platform.system() == "Windows":
run_pip("install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers") run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers")
elif platform.system() == "Linux": elif platform.system() == "Linux":
run_pip("install xformers", "xformers") run_pip("install xformers", "xformers")
if not is_installed("deepdanbooru") and deepdanbooru: if not is_installed("deepdanbooru") and deepdanbooru:
run_pip("install git+https://github.com/KichangKim/DeepDanbooru.git@edf73df4cdaeea2cf00e9ac08bd8a9026b7a7b26#egg=deepdanbooru[tensorflow] tensorflow==2.10.0 tensorflow-io==0.27.0", "deepdanbooru") run_pip(f"install {deepdanbooru_package}#egg=deepdanbooru[tensorflow] tensorflow==2.10.0 tensorflow-io==0.27.0", "deepdanbooru")
if not is_installed("pyngrok") and ngrok: if not is_installed("pyngrok") and ngrok:
run_pip("install pyngrok", "ngrok") run_pip("install pyngrok", "ngrok")
os.makedirs(dir_repos, exist_ok=True) os.makedirs(dir_repos, exist_ok=True)
git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash) git_clone(stable_diffusion_repo, repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash) git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash) git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
if not is_installed("lpips"): if not is_installed("lpips"):
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer") run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
run_pip(f"install -r {requirements_file}", "requirements for Web UI") run_pip(f"install -r {requirements_file}", "requirements for Web UI")
sys.argv += args if update_check:
version_check(commit)
if "--exit" in args: if "--exit" in sys.argv:
print("Exiting because of --exit argument") print("Exiting because of --exit argument")
exit(0) exit(0)

68
modules/api/api.py Normal file
View file

@ -0,0 +1,68 @@
from modules.api.processing import StableDiffusionProcessingAPI
from modules.processing import StableDiffusionProcessingTxt2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_pnginfo
import modules.shared as shared
import uvicorn
from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, Json
import json
import io
import base64
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json
info: Json
class Api:
def __init__(self, app, queue_lock):
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found")
populate = txt2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"do_not_save_samples": True,
"do_not_save_grid": True
}
)
p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param
with self.queue_lock:
processed = process_images(p)
b64images = []
for i in processed.images:
buffer = io.BytesIO()
i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue()))
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info))
def img2imgapi(self):
raise NotImplementedError
def extrasapi(self):
raise NotImplementedError
def pnginfoapi(self):
raise NotImplementedError
def launch(self, server_name, port):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port)

99
modules/api/processing.py Normal file
View file

@ -0,0 +1,99 @@
from inflection import underscore
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, create_model
from modules.processing import StableDiffusionProcessingTxt2Img
import inspect
API_NOT_ALLOWED = [
"self",
"kwargs",
"sd_model",
"outpath_samples",
"outpath_grids",
"sampler_index",
"do_not_save_samples",
"do_not_save_grid",
"extra_generation_params",
"overlay_images",
"do_not_reload_embeddings",
"seed_enable_extras",
"prompt_for_display",
"sampler_noise_scheduler_override",
"ddim_discretize"
]
class ModelDef(BaseModel):
"""Assistance Class for Pydantic Dynamic Model Generation"""
field: str
field_alias: str
field_type: Any
field_value: Any
class PydanticModelGenerator:
"""
Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
source_data is a snapshot of the default values produced by the class
params are the names of the actual keys required by __init__
"""
def __init__(
self,
model_name: str = None,
class_instance = None,
additional_fields = None,
):
def field_type_generator(k, v):
# field_type = str if not overrides.get(k) else overrides[k]["type"]
# print(k, v.annotation, v.default)
field_type = v.annotation
return Optional[field_type]
def merge_class_params(class_):
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
parameters = {}
for classes in all_classes:
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
return parameters
self._model_name = model_name
self._class_data = merge_class_params(class_instance)
self._model_def = [
ModelDef(
field=underscore(k),
field_alias=k,
field_type=field_type_generator(k, v),
field_value=v.default
)
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
]
for fields in additional_fields:
self._model_def.append(ModelDef(
field=underscore(fields["key"]),
field_alias=fields["key"],
field_type=fields["type"],
field_value=fields["default"]))
def generate_model(self):
"""
Creates a pydantic BaseModel
from the json and overrides provided at initialization
"""
fields = {
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
}
DynamicModel = create_model(self._model_name, **fields)
DynamicModel.__config__.allow_population_by_field_name = True
DynamicModel.__config__.allow_mutation = True
return DynamicModel
StableDiffusionProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingTxt2Img",
StableDiffusionProcessingTxt2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}]
).generate_model()

View file

@ -157,8 +157,7 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_o
# sort by reverse by likelihood and normal for alpha, and format tag text as requested # sort by reverse by likelihood and normal for alpha, and format tag text as requested
unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort)) unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
for weight, tag in unsorted_tags_in_theshold: for weight, tag in unsorted_tags_in_theshold:
# note: tag_outformat will still have a colon if include_ranks is True tag_outformat = tag
tag_outformat = tag.replace(':', ' ')
if use_spaces: if use_spaces:
tag_outformat = tag_outformat.replace('_', ' ') tag_outformat = tag_outformat.replace('_', ' ')
if use_escape: if use_escape:

View file

@ -20,12 +20,13 @@ import gradio as gr
cached_images = {} cached_images = {}
def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility): def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
devices.torch_gc() devices.torch_gc()
imageArr = [] imageArr = []
# Also keep track of original file names # Also keep track of original file names
imageNameArr = [] imageNameArr = []
outputs = []
if extras_mode == 1: if extras_mode == 1:
#convert file to pillow image #convert file to pillow image
@ -33,13 +34,26 @@ def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility,
image = Image.open(img) image = Image.open(img)
imageArr.append(image) imageArr.append(image)
imageNameArr.append(os.path.splitext(img.orig_name)[0]) imageNameArr.append(os.path.splitext(img.orig_name)[0])
elif extras_mode == 2:
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
if input_dir == '':
return outputs, "Please select an input directory.", ''
image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
for img in image_list:
image = Image.open(img)
imageArr.append(image)
imageNameArr.append(img)
else: else:
imageArr.append(image) imageArr.append(image)
imageNameArr.append(None) imageNameArr.append(None)
if extras_mode == 2 and output_dir != '':
outpath = output_dir
else:
outpath = opts.outdir_samples or opts.outdir_extras_samples outpath = opts.outdir_samples or opts.outdir_extras_samples
outputs = []
for image, image_name in zip(imageArr, imageNameArr): for image, image_name in zip(imageArr, imageNameArr):
if image is None: if image is None:
return outputs, "Please select an input image.", '' return outputs, "Please select an input image.", ''
@ -77,7 +91,8 @@ def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility,
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
pixels = tuple(np.array(small).flatten().tolist()) pixels = tuple(np.array(small).flatten().tolist())
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight,
resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels
c = cached_images.get(key) c = cached_images.get(key)
if c is None: if c is None:
@ -112,6 +127,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility,
image.info = existing_pnginfo image.info = existing_pnginfo
image.info["extras"] = info image.info["extras"] = info
if extras_mode != 2 or show_extras_results :
outputs.append(image) outputs.append(image)
devices.torch_gc() devices.torch_gc()
@ -160,11 +176,14 @@ def run_pnginfo(image):
def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name): def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
def weighted_sum(theta0, theta1, theta2, alpha): def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1) return ((1 - alpha) * theta0) + (alpha * theta1)
def add_difference(theta0, theta1, theta2, alpha): def get_difference(theta1, theta2):
return theta0 + (theta1 - theta2) * alpha return theta1 - theta2
def add_difference(theta0, theta1_2_diff, alpha):
return theta0 + (alpha * theta1_2_diff)
primary_model_info = sd_models.checkpoints_list[primary_model_name] primary_model_info = sd_models.checkpoints_list[primary_model_name]
secondary_model_info = sd_models.checkpoints_list[secondary_model_name] secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
@ -183,23 +202,31 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
teritary_model = torch.load(teritary_model_info.filename, map_location='cpu') teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model) theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
else: else:
teritary_model = None
theta_2 = None theta_2 = None
theta_funcs = { theta_funcs = {
"Weighted sum": weighted_sum, "Weighted sum": (None, weighted_sum),
"Add difference": add_difference, "Add difference": (get_difference, add_difference),
} }
theta_func = theta_funcs[interp_method] theta_func1, theta_func2 = theta_funcs[interp_method]
print(f"Merging...") print(f"Merging...")
if theta_func1:
for key in tqdm.tqdm(theta_1.keys()):
if 'model' in key:
if key in theta_2:
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
theta_1[key] = theta_func1(theta_1[key], t2)
else:
theta_1[key] = torch.zeros_like(theta_1[key])
del theta_2, teritary_model
for key in tqdm.tqdm(theta_0.keys()): for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1: if 'model' in key and key in theta_1:
t2 = (theta_2 or {}).get(key)
if t2 is None:
t2 = torch.zeros_like(theta_0[key])
theta_0[key] = theta_func(theta_0[key], theta_1[key], t2, multiplier) theta_0[key] = theta_func2(theta_0[key], theta_1[key], multiplier)
if save_as_half: if save_as_half:
theta_0[key] = theta_0[key].half() theta_0[key] = theta_0[key].half()

View file

@ -196,7 +196,7 @@ def stack_conds(conds):
return torch.stack(conds) return torch.stack(conds)
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert hypernetwork_name, 'hypernetwork not selected' assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None) path = shared.hypernetworks.get(hypernetwork_name, None)
@ -225,7 +225,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"): with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
if unload: if unload:
shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.cond_stage_model.to(devices.cpu)

View file

@ -1,6 +1,6 @@
import os import os
import shutil import shutil
import sys
def traverse_all_files(output_dir, image_list, curr_dir=None): def traverse_all_files(output_dir, image_list, curr_dir=None):
curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir) curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
@ -24,10 +24,14 @@ def traverse_all_files(output_dir, image_list, curr_dir=None):
def get_recent_images(dir_name, page_index, step, image_index, tabname): def get_recent_images(dir_name, page_index, step, image_index, tabname):
page_index = int(page_index) page_index = int(page_index)
f_list = os.listdir(dir_name)
image_list = [] image_list = []
if not os.path.exists(dir_name):
pass
elif os.path.isdir(dir_name):
image_list = traverse_all_files(dir_name, image_list) image_list = traverse_all_files(dir_name, image_list)
image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file))) image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file)))
else:
print(f'ERROR: "{dir_name}" is not a directory. Check the path in the settings.', file=sys.stderr)
num = 48 if tabname != "extras" else 12 num = 48 if tabname != "extras" else 12
max_page_index = len(image_list) // num + 1 max_page_index = len(image_list) // num + 1
page_index = max_page_index if page_index == -1 else page_index + step page_index = max_page_index if page_index == -1 else page_index + step
@ -105,10 +109,8 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
dir_name = opts.outdir_img2img_samples dir_name = opts.outdir_img2img_samples
elif tabname == "extras": elif tabname == "extras":
dir_name = opts.outdir_extras_samples dir_name = opts.outdir_extras_samples
d = dir_name.split("/") else:
dir_name = "/" if dir_name.startswith("/") else d[0] return
for p in d[1:]:
dir_name = os.path.join(dir_name, p)
with gr.Row(): with gr.Row():
renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page") renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page")
first_page = gr.Button('First Page') first_page = gr.Button('First Page')

View file

@ -123,7 +123,7 @@ class InterrogateModels:
return caption[0] return caption[0]
def interrogate(self, pil_image, include_ranks=False): def interrogate(self, pil_image):
res = None res = None
try: try:
@ -156,10 +156,10 @@ class InterrogateModels:
for name, topn, items in self.categories: for name, topn, items in self.categories:
matches = self.rank(image_features, items, top_count=topn) matches = self.rank(image_features, items, top_count=topn)
for match, score in matches: for match, score in matches:
if include_ranks: if shared.opts.interrogate_return_ranks:
res += ", " + match res += f", ({match}:{score/100:.3f})"
else: else:
res += f", ({match}:{score})" res += ", " + match
except Exception: except Exception:
print(f"Error interrogating", file=sys.stderr) print(f"Error interrogating", file=sys.stderr)

31
modules/localization.py Normal file
View file

@ -0,0 +1,31 @@
import json
import os
import sys
import traceback
localizations = {}
def list_localizations(dirname):
localizations.clear()
for file in os.listdir(dirname):
fn, ext = os.path.splitext(file)
if ext.lower() != ".json":
continue
localizations[fn] = os.path.join(dirname, file)
def localization_js(current_localization_name):
fn = localizations.get(current_localization_name, None)
data = {}
if fn is not None:
try:
with open(fn, "r", encoding="utf8") as file:
data = json.load(file)
except Exception:
print(f"Error loading localization from {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return f"var localization = {json.dumps(data)}\n"

View file

@ -1,12 +1,14 @@
from pyngrok import ngrok, conf, exception from pyngrok import ngrok, conf, exception
def connect(token, port): def connect(token, port, region):
if token == None: if token == None:
token = 'None' token = 'None'
conf.get_default().auth_token = token config = conf.PyngrokConfig(
auth_token=token, region=region
)
try: try:
public_url = ngrok.connect(port).public_url public_url = ngrok.connect(port, pyngrok_config=config).public_url
except exception.PyngrokNgrokError: except exception.PyngrokNgrokError:
print(f'Invalid ngrok authtoken, ngrok connection aborted.\n' print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken') f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')

View file

@ -9,6 +9,7 @@ from PIL import Image, ImageFilter, ImageOps
import random import random
import cv2 import cv2
from skimage import exposure from skimage import exposure
from typing import Any, Dict, List, Optional
import modules.sd_hijack import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram from modules import devices, prompt_parser, masking, sd_samplers, lowvram
@ -51,9 +52,15 @@ def get_correct_sampler(p):
return sd_samplers.samplers return sd_samplers.samplers
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img): elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
return sd_samplers.samplers_for_img2img return sd_samplers.samplers_for_img2img
elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
return sd_samplers.samplers
class StableDiffusionProcessing: class StableDiffusionProcessing():
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None): """
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps:int =50, cfg_scale:float=7.0, width:int=512, height:int=512, restore_faces:bool=False, tiling:bool=False, do_not_save_samples:bool=False, do_not_save_grid:bool=False, extra_generation_params: Dict[Any,Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float =None, do_not_reload_embeddings: bool=False, denoising_strength: float = 0, ddim_discretize: str = "uniform", s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0):
self.sd_model = sd_model self.sd_model = sd_model
self.outpath_samples: str = outpath_samples self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids self.outpath_grids: str = outpath_grids
@ -80,15 +87,16 @@ class StableDiffusionProcessing:
self.extra_generation_params: dict = extra_generation_params or {} self.extra_generation_params: dict = extra_generation_params or {}
self.overlay_images = overlay_images self.overlay_images = overlay_images
self.eta = eta self.eta = eta
self.do_not_reload_embeddings = do_not_reload_embeddings
self.paste_to = None self.paste_to = None
self.color_corrections = None self.color_corrections = None
self.denoising_strength: float = 0 self.denoising_strength: float = 0
self.sampler_noise_scheduler_override = None self.sampler_noise_scheduler_override = None
self.ddim_discretize = opts.ddim_discretize self.ddim_discretize = opts.ddim_discretize
self.s_churn = opts.s_churn self.s_churn = s_churn or opts.s_churn
self.s_tmin = opts.s_tmin self.s_tmin = s_tmin or opts.s_tmin
self.s_tmax = float('inf') # not representable as a standard ui option self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
self.s_noise = opts.s_noise self.s_noise = s_noise or opts.s_noise
if not seed_enable_extras: if not seed_enable_extras:
self.subseed = -1 self.subseed = -1
@ -96,6 +104,7 @@ class StableDiffusionProcessing:
self.seed_resize_from_h = 0 self.seed_resize_from_h = 0
self.seed_resize_from_w = 0 self.seed_resize_from_w = 0
def init(self, all_prompts, all_seeds, all_subseeds): def init(self, all_prompts, all_seeds, all_subseeds):
pass pass
@ -333,12 +342,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
seed = get_fixed_seed(p.seed) seed = get_fixed_seed(p.seed)
subseed = get_fixed_seed(p.subseed) subseed = get_fixed_seed(p.subseed)
if p.outpath_samples is not None:
os.makedirs(p.outpath_samples, exist_ok=True)
if p.outpath_grids is not None:
os.makedirs(p.outpath_grids, exist_ok=True)
modules.sd_hijack.model_hijack.apply_circular(p.tiling) modules.sd_hijack.model_hijack.apply_circular(p.tiling)
modules.sd_hijack.model_hijack.clear_comments() modules.sd_hijack.model_hijack.clear_comments()
@ -364,7 +367,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
def infotext(iteration=0, position_in_batch=0): def infotext(iteration=0, position_in_batch=0):
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch) return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
if os.path.exists(cmd_opts.embeddings_dir): if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings() model_hijack.embedding_db.load_textual_inversion_embeddings()
infotexts = [] infotexts = []
@ -407,12 +410,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
with devices.autocast(): with devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
if state.interrupted or state.skipped:
# if we are interrupted, sample returns just noise
# use the image collected previously in sampler loop
samples_ddim = shared.state.current_latent
samples_ddim = samples_ddim.to(devices.dtype_vae) samples_ddim = samples_ddim.to(devices.dtype_vae)
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim) x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@ -502,7 +499,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None sampler = None
def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=0, firstphase_height=0, **kwargs): def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.enable_hr = enable_hr self.enable_hr = enable_hr
self.denoising_strength = denoising_strength self.denoising_strength = denoising_strength

View file

@ -58,6 +58,9 @@ def load_scripts(basedir):
for filename in sorted(os.listdir(basedir)): for filename in sorted(os.listdir(basedir)):
path = os.path.join(basedir, filename) path = os.path.join(basedir, filename)
if os.path.splitext(path)[1].lower() != '.py':
continue
if not os.path.isfile(path): if not os.path.isfile(path):
continue continue
@ -93,6 +96,7 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
class ScriptRunner: class ScriptRunner:
def __init__(self): def __init__(self):
self.scripts = [] self.scripts = []
self.titles = []
def setup_ui(self, is_img2img): def setup_ui(self, is_img2img):
for script_class, path in scripts_data: for script_class, path in scripts_data:
@ -104,9 +108,10 @@ class ScriptRunner:
self.scripts.append(script) self.scripts.append(script)
titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts] self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]
dropdown = gr.Dropdown(label="Script", choices=["None"] + titles, value="None", type="index") dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
dropdown.save_to_config = True
inputs = [dropdown] inputs = [dropdown]
for script in self.scripts: for script in self.scripts:
@ -136,6 +141,15 @@ class ScriptRunner:
return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))] return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]
def init_field(title):
if title == 'None':
return
script_index = self.titles.index(title)
script = self.scripts[script_index]
for i in range(script.args_from, script.args_to):
inputs[i].visible = True
dropdown.init_field = init_field
dropdown.change( dropdown.change(
fn=select_script, fn=select_script,
inputs=[dropdown], inputs=[dropdown],

View file

@ -181,7 +181,7 @@ def einsum_op_cuda(q, k, v):
mem_free_torch = mem_reserved - mem_active mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch mem_free_total = mem_free_cuda + mem_free_torch
# Divide factor of safety as there's copying and fragmentation # Divide factor of safety as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
def einsum_op(q, k, v): def einsum_op(q, k, v):
if q.device.type == 'cuda': if q.device.type == 'cuda':
@ -296,10 +296,16 @@ def xformers_attnblock_forward(self, x):
try: try:
h_ = x h_ = x
h_ = self.norm(h_) h_ = self.norm(h_)
q1 = self.q(h_).contiguous() q = self.q(h_)
k1 = self.k(h_).contiguous() k = self.k(h_)
v = self.v(h_).contiguous() v = self.v(h_)
out = xformers.ops.memory_efficient_attention(q1, k1, v) b, c, h, w = q.shape
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
out = xformers.ops.memory_efficient_attention(q, k, v)
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
out = self.proj_out(out) out = self.proj_out(out)
return x + out return x + out
except NotImplementedError: except NotImplementedError:

View file

@ -122,11 +122,33 @@ def select_checkpoint():
return checkpoint_info return checkpoint_info
chckpoint_dict_replacements = {
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}
def transform_checkpoint_dict_key(k):
for text, replacement in chckpoint_dict_replacements.items():
if k.startswith(text):
k = replacement + k[len(text):]
return k
def get_state_dict_from_checkpoint(pl_sd): def get_state_dict_from_checkpoint(pl_sd):
if "state_dict" in pl_sd: if "state_dict" in pl_sd:
return pl_sd["state_dict"] pl_sd = pl_sd["state_dict"]
return pl_sd sd = {}
for k, v in pl_sd.items():
new_key = transform_checkpoint_dict_key(k)
if new_key is not None:
sd[new_key] = v
return sd
def load_model_weights(model, checkpoint_info): def load_model_weights(model, checkpoint_info):
@ -141,7 +163,7 @@ def load_model_weights(model, checkpoint_info):
print(f"Global Step: {pl_sd['global_step']}") print(f"Global Step: {pl_sd['global_step']}")
sd = get_state_dict_from_checkpoint(pl_sd) sd = get_state_dict_from_checkpoint(pl_sd)
model.load_state_dict(sd, strict=False) missing, extra = model.load_state_dict(sd, strict=False)
if shared.cmd_opts.opt_channelslast: if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last) model.to(memory_format=torch.channels_last)

View file

@ -98,25 +98,8 @@ def store_latent(decoded):
shared.state.current_image = sample_to_image(decoded) shared.state.current_image = sample_to_image(decoded)
class InterruptedException(BaseException):
def extended_tdqm(sequence, *args, desc=None, **kwargs): pass
state.sampling_steps = len(sequence)
state.sampling_step = 0
seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
for x in seq:
if state.interrupted or state.skipped:
break
yield x
state.sampling_step += 1
shared.total_tqdm.update()
ldm.models.diffusion.ddim.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
ldm.models.diffusion.plms.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
class VanillaStableDiffusionSampler: class VanillaStableDiffusionSampler:
@ -128,14 +111,32 @@ class VanillaStableDiffusionSampler:
self.init_latent = None self.init_latent = None
self.sampler_noises = None self.sampler_noises = None
self.step = 0 self.step = 0
self.stop_at = None
self.eta = None self.eta = None
self.default_eta = 0.0 self.default_eta = 0.0
self.config = None self.config = None
self.last_latent = None
def number_of_needed_noises(self, p): def number_of_needed_noises(self, p):
return 0 return 0
def launch_sampling(self, steps, func):
state.sampling_steps = steps
state.sampling_step = 0
try:
return func()
except InterruptedException:
return self.last_latent
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
if state.interrupted or state.skipped:
raise InterruptedException
if self.stop_at is not None and self.step > self.stop_at:
raise InterruptedException
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
@ -159,11 +160,16 @@ class VanillaStableDiffusionSampler:
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
if self.mask is not None: if self.mask is not None:
store_latent(self.init_latent * self.mask + self.nmask * res[1]) self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
else: else:
store_latent(res[1]) self.last_latent = res[1]
store_latent(self.last_latent)
self.step += 1 self.step += 1
state.sampling_step = self.step
shared.total_tqdm.update()
return res return res
def initialize(self, p): def initialize(self, p):
@ -192,7 +198,7 @@ class VanillaStableDiffusionSampler:
self.init_latent = x self.init_latent = x
self.step = 0 self.step = 0
samples = self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning) samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
return samples return samples
@ -206,9 +212,9 @@ class VanillaStableDiffusionSampler:
# existing code fails with certain step counts, like 9 # existing code fails with certain step counts, like 9
try: try:
samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta) samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
except Exception: except Exception:
samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta) samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
return samples_ddim return samples_ddim
@ -223,6 +229,9 @@ class CFGDenoiser(torch.nn.Module):
self.step = 0 self.step = 0
def forward(self, x, sigma, uncond, cond, cond_scale): def forward(self, x, sigma, uncond, cond, cond_scale):
if state.interrupted or state.skipped:
raise InterruptedException
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
@ -268,25 +277,6 @@ class CFGDenoiser(torch.nn.Module):
return denoised return denoised
def extended_trange(sampler, count, *args, **kwargs):
state.sampling_steps = count
state.sampling_step = 0
seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
for x in seq:
if state.interrupted or state.skipped:
break
if sampler.stop_at is not None and x > sampler.stop_at:
break
yield x
state.sampling_step += 1
shared.total_tqdm.update()
class TorchHijack: class TorchHijack:
def __init__(self, kdiff_sampler): def __init__(self, kdiff_sampler):
self.kdiff_sampler = kdiff_sampler self.kdiff_sampler = kdiff_sampler
@ -314,9 +304,28 @@ class KDiffusionSampler:
self.eta = None self.eta = None
self.default_eta = 1.0 self.default_eta = 1.0
self.config = None self.config = None
self.last_latent = None
def callback_state(self, d): def callback_state(self, d):
store_latent(d["denoised"]) step = d['i']
latent = d["denoised"]
store_latent(latent)
self.last_latent = latent
if self.stop_at is not None and step > self.stop_at:
raise InterruptedException
state.sampling_step = step
shared.total_tqdm.update()
def launch_sampling(self, steps, func):
state.sampling_steps = steps
state.sampling_step = 0
try:
return func()
except InterruptedException:
return self.last_latent
def number_of_needed_noises(self, p): def number_of_needed_noises(self, p):
return p.steps return p.steps
@ -339,9 +348,6 @@ class KDiffusionSampler:
self.sampler_noise_index = 0 self.sampler_noise_index = 0
self.eta = p.eta or opts.eta_ancestral self.eta = p.eta or opts.eta_ancestral
if hasattr(k_diffusion.sampling, 'trange'):
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs)
if self.sampler_noises is not None: if self.sampler_noises is not None:
k_diffusion.sampling.torch = TorchHijack(self) k_diffusion.sampling.torch = TorchHijack(self)
@ -383,8 +389,9 @@ class KDiffusionSampler:
self.model_wrap_cfg.init_latent = x self.model_wrap_cfg.init_latent = x
return self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs) samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None): def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
steps = steps or p.steps steps = steps or p.steps
@ -406,6 +413,8 @@ class KDiffusionSampler:
extra_params_kwargs['n'] = steps extra_params_kwargs['n'] = steps
else: else:
extra_params_kwargs['sigmas'] = sigmas extra_params_kwargs['sigmas'] = sigmas
samples = self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples return samples

View file

@ -13,7 +13,7 @@ import modules.memmon
import modules.sd_models import modules.sd_models
import modules.styles import modules.styles
import modules.devices as devices import modules.devices as devices
from modules import sd_samplers, sd_models from modules import sd_samplers, sd_models, localization
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path from modules.paths import models_path, script_path, sd_path
@ -31,6 +31,7 @@ parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
@ -40,6 +41,7 @@ parser.add_argument("--unload-gfpgan", action='store_true', help="does not do an
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer')) parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN')) parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN')) parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
@ -68,14 +70,26 @@ parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image upload
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
cmd_opts = parser.parse_args() cmd_opts = parser.parse_args()
restricted_opts = [
"samples_filename_pattern",
"outdir_samples",
"outdir_txt2img_samples",
"outdir_img2img_samples",
"outdir_extras_samples",
"outdir_grids",
"outdir_txt2img_grids",
"outdir_save",
]
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer']) (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'])
@ -92,7 +106,6 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
loaded_hypernetwork = None loaded_hypernetwork = None
def reload_hypernetworks(): def reload_hypernetworks():
global hypernetworks global hypernetworks
@ -140,6 +153,8 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = [] face_restorers = []
localization.list_localizations(cmd_opts.localizations_dir)
def realesrgan_models_names(): def realesrgan_models_names():
import modules.realesrgan_model import modules.realesrgan_model
@ -280,11 +295,13 @@ options_templates.update(options_section(('ui', "User interface"), {
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"), "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
"font": OptionInfo("", "Font for image grids that have text"), "font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
})) }))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), { options_templates.update(options_section(('sampler-params', "Sampler parameters"), {

View file

@ -45,7 +45,7 @@ class StyleDatabase:
if not os.path.exists(path): if not os.path.exists(path):
return return
with open(path, "r", encoding="utf8", newline='') as file: with open(path, "r", encoding="utf-8-sig", newline='') as file:
reader = csv.DictReader(file) reader = csv.DictReader(file)
for row in reader: for row in reader:
# Support loading old CSV format with "name, text"-columns # Support loading old CSV format with "name, text"-columns
@ -79,7 +79,7 @@ class StyleDatabase:
def save_styles(self, path: str) -> None: def save_styles(self, path: str) -> None:
# Write to temporary file first, so we don't nuke the file if something goes wrong # Write to temporary file first, so we don't nuke the file if something goes wrong
fd, temp_path = tempfile.mkstemp(".csv") fd, temp_path = tempfile.mkstemp(".csv")
with os.fdopen(fd, "w", encoding="utf8", newline='') as file: with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
# _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple, # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
# and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict() # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields) writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)

View file

@ -137,6 +137,7 @@ class EmbeddingDatabase:
continue continue
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.") print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
print("Embeddings:", ', '.join(self.word_embeddings.keys()))
def find_embedding_at_position(self, tokens, offset): def find_embedding_at_position(self, tokens, offset):
token = tokens[offset] token = tokens[offset]
@ -296,6 +297,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
sd_model=shared.sd_model, sd_model=shared.sd_model,
do_not_save_grid=True, do_not_save_grid=True,
do_not_save_samples=True, do_not_save_samples=True,
do_not_reload_embeddings=True,
) )
if preview_from_txt2img: if preview_from_txt2img:

View file

@ -23,9 +23,9 @@ import gradio as gr
import gradio.utils import gradio.utils
import gradio.routes import gradio.routes
from modules import sd_hijack, sd_models from modules import sd_hijack, sd_models, localization
from modules.paths import script_path from modules.paths import script_path
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts, restricted_opts
if cmd_opts.deepdanbooru: if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags from modules.deepbooru import get_deepbooru_tags
import modules.shared as shared import modules.shared as shared
@ -56,7 +56,7 @@ if not cmd_opts.share and not cmd_opts.listen:
if cmd_opts.ngrok != None: if cmd_opts.ngrok != None:
import modules.ngrok as ngrok import modules.ngrok as ngrok
print('ngrok authtoken detected, trying to connect...') print('ngrok authtoken detected, trying to connect...')
ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860) ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860, cmd_opts.ngrok_region)
def gr_show(visible=True): def gr_show(visible=True):
@ -261,6 +261,19 @@ def wrap_gradio_call(func, extra_outputs=None):
return f return f
def calc_time_left(progress, threshold, label, force_display):
if progress == 0:
return ""
else:
time_since_start = time.time() - shared.state.time_start
eta = (time_since_start/progress)
eta_relative = eta-time_since_start
if (eta_relative > threshold and progress > 0.02) or force_display:
return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
else:
return ""
def check_progress_call(id_part): def check_progress_call(id_part):
if shared.state.job_count == 0: if shared.state.job_count == 0:
return "", gr_show(False), gr_show(False), gr_show(False) return "", gr_show(False), gr_show(False), gr_show(False)
@ -272,11 +285,15 @@ def check_progress_call(id_part):
if shared.state.sampling_steps > 0: if shared.state.sampling_steps > 0:
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
time_left = calc_time_left( progress, 60, " ETA:", shared.state.time_left_force_display )
if time_left != "":
shared.state.time_left_force_display = True
progress = min(progress, 1) progress = min(progress, 1)
progressbar = "" progressbar = ""
if opts.show_progressbar: if opts.show_progressbar:
progressbar = f"""<div class='progressDiv'><div class='progress' style="width:{progress * 100}%">{str(int(progress*100))+"%" if progress > 0.01 else ""}</div></div>""" progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:hidden;width:{progress * 100}%">{str(int(progress*100))+"%"+time_left if progress > 0.01 else ""}</div></div>"""
image = gr_show(False) image = gr_show(False)
preview_visibility = gr_show(False) preview_visibility = gr_show(False)
@ -308,6 +325,8 @@ def check_progress_call_initial(id_part):
shared.state.current_latent = None shared.state.current_latent = None
shared.state.current_image = None shared.state.current_image = None
shared.state.textinfo = None shared.state.textinfo = None
shared.state.time_start = time.time()
shared.state.time_left_force_display = False
return check_progress_call(id_part) return check_progress_call(id_part)
@ -508,9 +527,11 @@ def create_toprow(is_img2img):
with gr.Row(): with gr.Row():
with gr.Column(scale=1, elem_id="style_pos_col"): with gr.Column(scale=1, elem_id="style_pos_col"):
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
prompt_style.save_to_config = True
with gr.Column(scale=1, elem_id="style_neg_col"): with gr.Column(scale=1, elem_id="style_neg_col"):
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
prompt_style2.save_to_config = True
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
@ -540,6 +561,10 @@ def apply_setting(key, value):
if value is None: if value is None:
return gr.update() return gr.update()
# dont allow model to be swapped when model hash exists in prompt
if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
return gr.update()
if key == "sd_model_checkpoint": if key == "sd_model_checkpoint":
ckpt_info = sd_models.get_closet_checkpoint_match(value) ckpt_info = sd_models.get_closet_checkpoint_match(value)
@ -566,6 +591,24 @@ def create_ui(wrap_gradio_gpu_call):
import modules.img2img import modules.img2img
import modules.txt2img import modules.txt2img
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh():
refresh_method()
args = refreshed_args() if callable(refreshed_args) else refreshed_args
for k, v in args.items():
setattr(refresh_component, k, v)
return gr.update(**(args or {}))
refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
refresh_button.click(
fn = refresh,
inputs = [],
outputs = [refresh_component]
)
return refresh_button
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False) dummy_component = gr.Label(visible=False)
@ -1016,6 +1059,15 @@ def create_ui(wrap_gradio_gpu_call):
with gr.TabItem('Batch Process'): with gr.TabItem('Batch Process'):
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file") image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
with gr.TabItem('Batch from Directory'):
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs,
placeholder="A directory on the same machine where the server is running."
)
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs,
placeholder="Leave blank to save images to the default path."
)
show_extras_results = gr.Checkbox(label='Show result images', value=True)
with gr.Tabs(elem_id="extras_resize_mode"): with gr.Tabs(elem_id="extras_resize_mode"):
with gr.TabItem('Scale by'): with gr.TabItem('Scale by'):
upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2) upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
@ -1027,10 +1079,10 @@ def create_ui(wrap_gradio_gpu_call):
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True) upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
with gr.Group(): with gr.Group():
extras_upscaler_1 = gr.Radio(label='Upscaler 1', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
with gr.Group(): with gr.Group():
extras_upscaler_2 = gr.Radio(label='Upscaler 2', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1) extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1)
with gr.Group(): with gr.Group():
@ -1060,6 +1112,9 @@ def create_ui(wrap_gradio_gpu_call):
dummy_component, dummy_component,
extras_image, extras_image,
image_batch, image_batch,
extras_batch_input_dir,
extras_batch_output_dir,
show_extras_results,
gfpgan_visibility, gfpgan_visibility,
codeformer_visibility, codeformer_visibility,
codeformer_weight, codeformer_weight,
@ -1191,8 +1246,12 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tab(label="Train"): with gr.Tab(label="Train"):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>") gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) with gr.Row():
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()]) train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
with gr.Row():
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
batch_size = gr.Number(label='Batch size', value=1, precision=0) batch_size = gr.Number(label='Batch size', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
@ -1301,6 +1360,8 @@ def create_ui(wrap_gradio_gpu_call):
batch_size, batch_size,
dataset_directory, dataset_directory,
log_directory, log_directory,
training_width,
training_height,
steps, steps,
create_image_every, create_image_every,
save_embedding_every, save_embedding_every,
@ -1340,31 +1401,18 @@ def create_ui(wrap_gradio_gpu_call):
else: else:
raise Exception(f'bad options item type: {str(t)} for key {key}') raise Exception(f'bad options item type: {str(t)} for key {key}')
elem_id = "setting_"+key
if info.refresh is not None: if info.refresh is not None:
if is_quicksettings: if is_quicksettings:
res = comp(label=info.label, value=fun, **(args or {})) res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_"+key) create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else: else:
with gr.Row(variant="compact"): with gr.Row(variant="compact"):
res = comp(label=info.label, value=fun, **(args or {})) res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_" + key) create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
def refresh():
info.refresh()
refreshed_args = info.component_args() if callable(info.component_args) else info.component_args
for k, v in refreshed_args.items():
setattr(res, k, v)
return gr.update(**(refreshed_args or {}))
refresh_button.click(
fn=refresh,
inputs=[],
outputs=[res],
)
else: else:
res = comp(label=info.label, value=fun, **(args or {})) res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
return res return res
@ -1373,7 +1421,10 @@ def create_ui(wrap_gradio_gpu_call):
component_dict = {} component_dict = {}
def open_folder(f): def open_folder(f):
if not os.path.isdir(f): if not os.path.exists(f):
print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
return
elif not os.path.isdir(f):
print(f""" print(f"""
WARNING WARNING
An open_folder request was made with an argument that is not a folder. An open_folder request was made with an argument that is not a folder.
@ -1406,6 +1457,9 @@ Requested path was: {f}
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
continue continue
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
continue
oldval = opts.data.get(key, None) oldval = opts.data.get(key, None)
opts.data[key] = value opts.data[key] = value
@ -1423,6 +1477,9 @@ Requested path was: {f}
if not opts.same_type(value, opts.data_labels[key].default): if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson() return gr.update(visible=True), opts.dumpjson()
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
return gr.update(value=oldval), opts.dumpjson()
oldval = opts.data.get(key, None) oldval = opts.data.get(key, None)
opts.data[key] = value opts.data[key] = value
@ -1479,6 +1536,9 @@ Requested path was: {f}
with gr.Row(): with gr.Row():
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
with gr.Row():
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
@ -1489,6 +1549,13 @@ Requested path was: {f}
_js='function(){}' _js='function(){}'
) )
download_localization.click(
fn=lambda: None,
inputs=[],
outputs=[],
_js='download_localization'
)
def reload_scripts(): def reload_scripts():
modules.scripts.reload_script_body_only() modules.scripts.reload_script_body_only()
reload_javascript() # need to refresh the html page reload_javascript() # need to refresh the html page
@ -1692,7 +1759,7 @@ Requested path was: {f}
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
def loadsave(path, x): def loadsave(path, x):
def apply_field(obj, field, condition=None): def apply_field(obj, field, condition=None, init_field=None):
key = path + "/" + field key = path + "/" + field
if getattr(obj,'custom_script_source',None) is not None: if getattr(obj,'custom_script_source',None) is not None:
@ -1704,8 +1771,12 @@ Requested path was: {f}
saved_value = ui_settings.get(key, None) saved_value = ui_settings.get(key, None)
if saved_value is None: if saved_value is None:
ui_settings[key] = getattr(obj, field) ui_settings[key] = getattr(obj, field)
elif condition is None or condition(saved_value): elif condition and not condition(saved_value):
print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
else:
setattr(obj, field, saved_value) setattr(obj, field, saved_value)
if init_field is not None:
init_field(saved_value)
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible: if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible:
apply_field(x, 'visible') apply_field(x, 'visible')
@ -1728,9 +1799,16 @@ Requested path was: {f}
if type(x) == gr.Number: if type(x) == gr.Number:
apply_field(x, 'value') apply_field(x, 'value')
# Since there are many dropdowns that shouldn't be saved,
# we only mark dropdowns that should be saved.
if type(x) == gr.Dropdown and getattr(x, 'save_to_config', False):
apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None))
apply_field(x, 'visible')
visit(txt2img_interface, loadsave, "txt2img") visit(txt2img_interface, loadsave, "txt2img")
visit(img2img_interface, loadsave, "img2img") visit(img2img_interface, loadsave, "img2img")
visit(extras_interface, loadsave, "extras") visit(extras_interface, loadsave, "extras")
visit(modelmerger_interface, loadsave, "modelmerger")
if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
with open(ui_config_file, "w", encoding="utf8") as file: with open(ui_config_file, "w", encoding="utf8") as file:
@ -1748,6 +1826,11 @@ def load_javascript(raw_response):
with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile: with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>" javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"
if cmd_opts.theme is not None:
javascript += f"\n<script>set_theme('{cmd_opts.theme}');</script>\n"
javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"
def template_response(*args, **kwargs): def template_response(*args, **kwargs):
res = raw_response(*args, **kwargs) res = raw_response(*args, **kwargs)
res.body = res.body.replace( res.body = res.body.replace(

View file

@ -23,3 +23,4 @@ resize-right
torchdiffeq torchdiffeq
kornia kornia
lark lark
inflection

View file

@ -22,3 +22,4 @@ resize-right==0.0.2
torchdiffeq==0.2.3 torchdiffeq==0.2.3
kornia==0.6.7 kornia==0.6.7
lark==1.1.2 lark==1.1.2
inflection==0.5.1

View file

@ -21,20 +21,20 @@ function onUiTabChange(callback){
uiTabChangeCallbacks.push(callback) uiTabChangeCallbacks.push(callback)
} }
function runCallback(x){ function runCallback(x, m){
try { try {
x() x(m)
} catch (e) { } catch (e) {
(console.error || console.log).call(console, e.message, e); (console.error || console.log).call(console, e.message, e);
} }
} }
function executeCallbacks(queue) { function executeCallbacks(queue, m) {
queue.forEach(runCallback) queue.forEach(function(x){runCallback(x, m)})
} }
document.addEventListener("DOMContentLoaded", function() { document.addEventListener("DOMContentLoaded", function() {
var mutationObserver = new MutationObserver(function(m){ var mutationObserver = new MutationObserver(function(m){
executeCallbacks(uiUpdateCallbacks); executeCallbacks(uiUpdateCallbacks, m);
const newTab = get_uiCurrentTab(); const newTab = get_uiCurrentTab();
if ( newTab && ( newTab !== uiCurrentTab ) ) { if ( newTab && ( newTab !== uiCurrentTab ) ) {
uiCurrentTab = newTab; uiCurrentTab = newTab;

View file

@ -233,6 +233,21 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_
return processed_result return processed_result
class SharedSettingsStackHelper(object):
def __enter__(self):
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
self.hypernetwork = opts.sd_hypernetwork
self.model = shared.sd_model
def __exit__(self, exc_type, exc_value, tb):
modules.sd_models.reload_model_weights(self.model)
hypernetwork.load_hypernetwork(self.hypernetwork)
hypernetwork.apply_strength()
opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*") re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")
@ -267,9 +282,6 @@ class Script(scripts.Script):
if not opts.return_grid: if not opts.return_grid:
p.batch_size = 1 p.batch_size = 1
CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
def process_axis(opt, vals): def process_axis(opt, vals):
if opt.label == 'Nothing': if opt.label == 'Nothing':
return [0] return [0]
@ -367,6 +379,7 @@ class Script(scripts.Script):
return process_images(pc) return process_images(pc)
with SharedSettingsStackHelper():
processed = draw_xy_grid( processed = draw_xy_grid(
p, p,
xs=xs, xs=xs,
@ -381,13 +394,4 @@ class Script(scripts.Script):
if opts.grid_save: if opts.grid_save:
images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p) images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p)
# restore checkpoint in case it was changed by axes
modules.sd_models.reload_model_weights(shared.sd_model)
hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
hypernetwork.apply_strength()
opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers
return processed return processed

View file

@ -478,7 +478,7 @@ input[type="range"]{
padding: 0; padding: 0;
} }
#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork{ #refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
max-width: 2.5em; max-width: 2.5em;
min-width: 2.5em; min-width: 2.5em;
height: 2.4em; height: 2.4em;

View file

@ -33,7 +33,7 @@ goto :launch
:skip_venv :skip_venv
:launch :launch
%PYTHON% launch.py %PYTHON% launch.py %*
pause pause
exit /b exit /b

View file

@ -4,7 +4,7 @@ import time
import importlib import importlib
import signal import signal
import threading import threading
from fastapi import FastAPI
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
from modules.paths import script_path from modules.paths import script_path
@ -31,7 +31,6 @@ from modules.paths import script_path
from modules.shared import cmd_opts from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork import modules.hypernetworks.hypernetwork
queue_lock = threading.Lock() queue_lock = threading.Lock()
@ -87,10 +86,6 @@ def initialize():
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
def webui():
initialize()
# make the program just exit at ctrl+c without waiting for anything # make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame): def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}') print(f'Interrupted with signal {sig} in {frame}')
@ -98,8 +93,35 @@ def webui():
signal.signal(signal.SIGINT, sigint_handler) signal.signal(signal.SIGINT, sigint_handler)
while 1:
def create_api(app):
from modules.api.api import Api
api = Api(app, queue_lock)
return api
def wait_on_server(demo=None):
while 1:
time.sleep(0.5)
if demo and getattr(demo, 'do_restart', False):
time.sleep(0.5)
demo.close()
time.sleep(0.5)
break
def api_only():
initialize()
app = FastAPI()
app.add_middleware(GZipMiddleware, minimum_size=1000)
api = create_api(app)
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
def webui(launch_api=False):
initialize()
while 1:
demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
app, local_url, share_url = demo.launch( app, local_url, share_url = demo.launch(
@ -114,13 +136,10 @@ def webui():
app.add_middleware(GZipMiddleware, minimum_size=1000) app.add_middleware(GZipMiddleware, minimum_size=1000)
while 1: if (launch_api):
time.sleep(0.5) create_api(app)
if getattr(demo, 'do_restart', False):
time.sleep(0.5) wait_on_server(demo)
demo.close()
time.sleep(0.5)
break
sd_samplers.set_samplers() sd_samplers.set_samplers()
@ -133,5 +152,10 @@ def webui():
print('Restarting Gradio') print('Restarting Gradio')
task = []
if __name__ == "__main__": if __name__ == "__main__":
webui() if cmd_opts.nowebui:
api_only()
else:
webui(cmd_opts.api)

View file

@ -138,4 +138,4 @@ fi
printf "\n%s\n" "${delimiter}" printf "\n%s\n" "${delimiter}"
printf "Launching launch.py..." printf "Launching launch.py..."
printf "\n%s\n" "${delimiter}" printf "\n%s\n" "${delimiter}"
"${python_cmd}" "${LAUNCH_SCRIPT}" "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"