Merge branch 'master' into hot-reload-javascript
This commit is contained in:
commit
05315d8a23
36 changed files with 858 additions and 232 deletions
6
.github/workflows/on_pull_request.yaml
vendored
6
.github/workflows/on_pull_request.yaml
vendored
|
@ -22,6 +22,12 @@ jobs:
|
||||||
uses: actions/setup-python@v3
|
uses: actions/setup-python@v3
|
||||||
with:
|
with:
|
||||||
python-version: 3.10.6
|
python-version: 3.10.6
|
||||||
|
- uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: ~/.cache/pip
|
||||||
|
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-pip-
|
||||||
- name: Install PyLint
|
- name: Install PyLint
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
|
|
|
@ -523,7 +523,6 @@ Affandi,0.7170285,nudity
|
||||||
Diane Arbus,0.655138,digipa-high-impact
|
Diane Arbus,0.655138,digipa-high-impact
|
||||||
Joseph Ducreux,0.65247905,digipa-high-impact
|
Joseph Ducreux,0.65247905,digipa-high-impact
|
||||||
Berthe Morisot,0.7165984,fineart
|
Berthe Morisot,0.7165984,fineart
|
||||||
Hilma AF Klint,0.71643853,scribbles
|
|
||||||
Hilma af Klint,0.71643853,scribbles
|
Hilma af Klint,0.71643853,scribbles
|
||||||
Filippino Lippi,0.7163017,fineart
|
Filippino Lippi,0.7163017,fineart
|
||||||
Leonid Afremov,0.7163005,fineart
|
Leonid Afremov,0.7163005,fineart
|
||||||
|
@ -738,14 +737,12 @@ Abraham Mignon,0.60605425,fineart
|
||||||
Albert Bloch,0.69573116,nudity
|
Albert Bloch,0.69573116,nudity
|
||||||
Charles Dana Gibson,0.67155975,fineart
|
Charles Dana Gibson,0.67155975,fineart
|
||||||
Alexandre-Évariste Fragonard,0.6507174,fineart
|
Alexandre-Évariste Fragonard,0.6507174,fineart
|
||||||
Alexandre-Évariste Fragonard,0.6507174,fineart
|
|
||||||
Ernst Fuchs,0.6953538,nudity
|
Ernst Fuchs,0.6953538,nudity
|
||||||
Alfredo Jaar,0.6952965,digipa-high-impact
|
Alfredo Jaar,0.6952965,digipa-high-impact
|
||||||
Judy Chicago,0.6952246,weird
|
Judy Chicago,0.6952246,weird
|
||||||
Frans van Mieris the Younger,0.6951849,fineart
|
Frans van Mieris the Younger,0.6951849,fineart
|
||||||
Aertgen van Leyden,0.6951305,fineart
|
Aertgen van Leyden,0.6951305,fineart
|
||||||
Emily Carr,0.69512105,fineart
|
Emily Carr,0.69512105,fineart
|
||||||
Frances Macdonald,0.6950408,scribbles
|
|
||||||
Frances MacDonald,0.6950408,scribbles
|
Frances MacDonald,0.6950408,scribbles
|
||||||
Hannah Höch,0.69495845,scribbles
|
Hannah Höch,0.69495845,scribbles
|
||||||
Gillis Rombouts,0.58770025,fineart
|
Gillis Rombouts,0.58770025,fineart
|
||||||
|
@ -895,7 +892,6 @@ Richard McGuire,0.6820089,scribbles
|
||||||
Anni Albers,0.65708244,digipa-high-impact
|
Anni Albers,0.65708244,digipa-high-impact
|
||||||
Aleksey Savrasov,0.65207493,fineart
|
Aleksey Savrasov,0.65207493,fineart
|
||||||
Wayne Barlowe,0.6537874,fineart
|
Wayne Barlowe,0.6537874,fineart
|
||||||
Giorgio De Chirico,0.6815907,fineart
|
|
||||||
Giorgio de Chirico,0.6815907,fineart
|
Giorgio de Chirico,0.6815907,fineart
|
||||||
Ernest Procter,0.6815795,fineart
|
Ernest Procter,0.6815795,fineart
|
||||||
Adriaen Brouwer,0.6815058,fineart
|
Adriaen Brouwer,0.6815058,fineart
|
||||||
|
@ -1241,7 +1237,6 @@ Betty Churcher,0.65387225,fineart
|
||||||
Claes Corneliszoon Moeyaert,0.65386075,fineart
|
Claes Corneliszoon Moeyaert,0.65386075,fineart
|
||||||
David Bomberg,0.6537477,fineart
|
David Bomberg,0.6537477,fineart
|
||||||
Abraham Bosschaert,0.6535562,fineart
|
Abraham Bosschaert,0.6535562,fineart
|
||||||
Giuseppe De Nittis,0.65354455,fineart
|
|
||||||
Giuseppe de Nittis,0.65354455,fineart
|
Giuseppe de Nittis,0.65354455,fineart
|
||||||
John La Farge,0.65342575,fineart
|
John La Farge,0.65342575,fineart
|
||||||
Frits Thaulow,0.65341854,fineart
|
Frits Thaulow,0.65341854,fineart
|
||||||
|
@ -1522,7 +1517,6 @@ Gertrude Harvey,0.5903887,fineart
|
||||||
Grant Wood,0.6266253,fineart
|
Grant Wood,0.6266253,fineart
|
||||||
Fyodor Vasilyev,0.5234919,digipa-med-impact
|
Fyodor Vasilyev,0.5234919,digipa-med-impact
|
||||||
Cagnaccio di San Pietro,0.6261671,fineart
|
Cagnaccio di San Pietro,0.6261671,fineart
|
||||||
Cagnaccio Di San Pietro,0.6261671,fineart
|
|
||||||
Doris Boulton-Maude,0.62593174,fineart
|
Doris Boulton-Maude,0.62593174,fineart
|
||||||
Adolf Hirémy-Hirschl,0.5946784,fineart
|
Adolf Hirémy-Hirschl,0.5946784,fineart
|
||||||
Harold von Schmidt,0.6256755,fineart
|
Harold von Schmidt,0.6256755,fineart
|
||||||
|
@ -2411,7 +2405,6 @@ Hermann Feierabend,0.5346168,digipa-high-impact
|
||||||
Antonio Donghi,0.4610982,digipa-low-impact
|
Antonio Donghi,0.4610982,digipa-low-impact
|
||||||
Adonna Khare,0.4858036,digipa-med-impact
|
Adonna Khare,0.4858036,digipa-med-impact
|
||||||
James Stokoe,0.5015107,digipa-med-impact
|
James Stokoe,0.5015107,digipa-med-impact
|
||||||
Art & Language,0.5341332,digipa-high-impact
|
|
||||||
Agustín Fernández,0.53403986,fineart
|
Agustín Fernández,0.53403986,fineart
|
||||||
Germán Londoño,0.5338712,fineart
|
Germán Londoño,0.5338712,fineart
|
||||||
Emmanuelle Moureaux,0.5335641,digipa-high-impact
|
Emmanuelle Moureaux,0.5335641,digipa-high-impact
|
||||||
|
|
|
|
@ -9,9 +9,38 @@ addEventListener('keydown', (event) => {
|
||||||
let minus = "ArrowDown"
|
let minus = "ArrowDown"
|
||||||
if (event.key != plus && event.key != minus) return;
|
if (event.key != plus && event.key != minus) return;
|
||||||
|
|
||||||
selectionStart = target.selectionStart;
|
let selectionStart = target.selectionStart;
|
||||||
selectionEnd = target.selectionEnd;
|
let selectionEnd = target.selectionEnd;
|
||||||
if(selectionStart == selectionEnd) return;
|
// If the user hasn't selected anything, let's select their current parenthesis block
|
||||||
|
if (selectionStart === selectionEnd) {
|
||||||
|
// Find opening parenthesis around current cursor
|
||||||
|
const before = target.value.substring(0, selectionStart);
|
||||||
|
let beforeParen = before.lastIndexOf("(");
|
||||||
|
if (beforeParen == -1) return;
|
||||||
|
let beforeParenClose = before.lastIndexOf(")");
|
||||||
|
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
|
||||||
|
beforeParen = before.lastIndexOf("(", beforeParen - 1);
|
||||||
|
beforeParenClose = before.lastIndexOf(")", beforeParenClose - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find closing parenthesis around current cursor
|
||||||
|
const after = target.value.substring(selectionStart);
|
||||||
|
let afterParen = after.indexOf(")");
|
||||||
|
if (afterParen == -1) return;
|
||||||
|
let afterParenOpen = after.indexOf("(");
|
||||||
|
while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
|
||||||
|
afterParen = after.indexOf(")", afterParen + 1);
|
||||||
|
afterParenOpen = after.indexOf("(", afterParenOpen + 1);
|
||||||
|
}
|
||||||
|
if (beforeParen === -1 || afterParen === -1) return;
|
||||||
|
|
||||||
|
// Set the selection to the text between the parenthesis
|
||||||
|
const parenContent = target.value.substring(beforeParen + 1, selectionStart + afterParen);
|
||||||
|
const lastColon = parenContent.lastIndexOf(":");
|
||||||
|
selectionStart = beforeParen + 1;
|
||||||
|
selectionEnd = selectionStart + lastColon;
|
||||||
|
target.setSelectionRange(selectionStart, selectionEnd);
|
||||||
|
}
|
||||||
|
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
|
|
||||||
|
|
|
@ -31,8 +31,8 @@ function imageMaskResize() {
|
||||||
|
|
||||||
wrapper.style.width = `${wW}px`;
|
wrapper.style.width = `${wW}px`;
|
||||||
wrapper.style.height = `${wH}px`;
|
wrapper.style.height = `${wH}px`;
|
||||||
wrapper.style.left = `${(w-wW)/2}px`;
|
wrapper.style.left = `0px`;
|
||||||
wrapper.style.top = `${(h-wH)/2}px`;
|
wrapper.style.top = `0px`;
|
||||||
|
|
||||||
canvases.forEach( c => {
|
canvases.forEach( c => {
|
||||||
c.style.width = c.style.height = '';
|
c.style.width = c.style.height = '';
|
||||||
|
|
|
@ -31,7 +31,7 @@ function updateOnBackgroundChange() {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
if (modalImage.src != currentButton.children[0].src) {
|
if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
|
||||||
modalImage.src = currentButton.children[0].src;
|
modalImage.src = currentButton.children[0].src;
|
||||||
if (modalImage.style.display === 'none') {
|
if (modalImage.style.display === 'none') {
|
||||||
modal.style.setProperty('background-image', `url(${modalImage.src})`)
|
modal.style.setProperty('background-image', `url(${modalImage.src})`)
|
||||||
|
@ -116,6 +116,7 @@ function showGalleryImage() {
|
||||||
e.dataset.modded = true;
|
e.dataset.modded = true;
|
||||||
if(e && e.parentElement.tagName == 'DIV'){
|
if(e && e.parentElement.tagName == 'DIV'){
|
||||||
e.style.cursor='pointer'
|
e.style.cursor='pointer'
|
||||||
|
e.style.userSelect='none'
|
||||||
e.addEventListener('click', function (evt) {
|
e.addEventListener('click', function (evt) {
|
||||||
if(!opts.js_modal_lightbox) return;
|
if(!opts.js_modal_lightbox) return;
|
||||||
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
|
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
|
||||||
|
|
146
javascript/localization.js
Normal file
146
javascript/localization.js
Normal file
|
@ -0,0 +1,146 @@
|
||||||
|
|
||||||
|
// localization = {} -- the dict with translations is created by the backend
|
||||||
|
|
||||||
|
ignore_ids_for_localization={
|
||||||
|
setting_sd_hypernetwork: 'OPTION',
|
||||||
|
setting_sd_model_checkpoint: 'OPTION',
|
||||||
|
setting_realesrgan_enabled_models: 'OPTION',
|
||||||
|
modelmerger_primary_model_name: 'OPTION',
|
||||||
|
modelmerger_secondary_model_name: 'OPTION',
|
||||||
|
modelmerger_tertiary_model_name: 'OPTION',
|
||||||
|
train_embedding: 'OPTION',
|
||||||
|
train_hypernetwork: 'OPTION',
|
||||||
|
txt2img_style_index: 'OPTION',
|
||||||
|
txt2img_style2_index: 'OPTION',
|
||||||
|
img2img_style_index: 'OPTION',
|
||||||
|
img2img_style2_index: 'OPTION',
|
||||||
|
setting_random_artist_categories: 'SPAN',
|
||||||
|
setting_face_restoration_model: 'SPAN',
|
||||||
|
setting_realesrgan_enabled_models: 'SPAN',
|
||||||
|
extras_upscaler_1: 'SPAN',
|
||||||
|
extras_upscaler_2: 'SPAN',
|
||||||
|
}
|
||||||
|
|
||||||
|
re_num = /^[\.\d]+$/
|
||||||
|
re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
|
||||||
|
|
||||||
|
original_lines = {}
|
||||||
|
translated_lines = {}
|
||||||
|
|
||||||
|
function textNodesUnder(el){
|
||||||
|
var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
|
||||||
|
while(n=walk.nextNode()) a.push(n);
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
|
||||||
|
function canBeTranslated(node, text){
|
||||||
|
if(! text) return false;
|
||||||
|
if(! node.parentElement) return false;
|
||||||
|
|
||||||
|
parentType = node.parentElement.nodeName
|
||||||
|
if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
|
||||||
|
|
||||||
|
if (parentType=='OPTION' || parentType=='SPAN'){
|
||||||
|
pnode = node
|
||||||
|
for(var level=0; level<4; level++){
|
||||||
|
pnode = pnode.parentElement
|
||||||
|
if(! pnode) break;
|
||||||
|
|
||||||
|
if(ignore_ids_for_localization[pnode.id] == parentType) return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if(re_num.test(text)) return false;
|
||||||
|
if(re_emoji.test(text)) return false;
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
function getTranslation(text){
|
||||||
|
if(! text) return undefined
|
||||||
|
|
||||||
|
if(translated_lines[text] === undefined){
|
||||||
|
original_lines[text] = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
tl = localization[text]
|
||||||
|
if(tl !== undefined){
|
||||||
|
translated_lines[tl] = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return tl
|
||||||
|
}
|
||||||
|
|
||||||
|
function processTextNode(node){
|
||||||
|
text = node.textContent.trim()
|
||||||
|
|
||||||
|
if(! canBeTranslated(node, text)) return
|
||||||
|
|
||||||
|
tl = getTranslation(text)
|
||||||
|
if(tl !== undefined){
|
||||||
|
node.textContent = tl
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function processNode(node){
|
||||||
|
if(node.nodeType == 3){
|
||||||
|
processTextNode(node)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if(node.title){
|
||||||
|
tl = getTranslation(node.title)
|
||||||
|
if(tl !== undefined){
|
||||||
|
node.title = tl
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if(node.placeholder){
|
||||||
|
tl = getTranslation(node.placeholder)
|
||||||
|
if(tl !== undefined){
|
||||||
|
node.placeholder = tl
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
textNodesUnder(node).forEach(function(node){
|
||||||
|
processTextNode(node)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
function dumpTranslations(){
|
||||||
|
dumped = {}
|
||||||
|
|
||||||
|
Object.keys(original_lines).forEach(function(text){
|
||||||
|
if(dumped[text] !== undefined) return
|
||||||
|
|
||||||
|
dumped[text] = localization[text] || text
|
||||||
|
})
|
||||||
|
|
||||||
|
return dumped
|
||||||
|
}
|
||||||
|
|
||||||
|
onUiUpdate(function(m){
|
||||||
|
m.forEach(function(mutation){
|
||||||
|
mutation.addedNodes.forEach(function(node){
|
||||||
|
processNode(node)
|
||||||
|
})
|
||||||
|
});
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
document.addEventListener("DOMContentLoaded", function() {
|
||||||
|
processNode(gradioApp())
|
||||||
|
})
|
||||||
|
|
||||||
|
function download_localization() {
|
||||||
|
text = JSON.stringify(dumpTranslations(), null, 4)
|
||||||
|
|
||||||
|
var element = document.createElement('a');
|
||||||
|
element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
|
||||||
|
element.setAttribute('download', "localization.json");
|
||||||
|
element.style.display = 'none';
|
||||||
|
document.body.appendChild(element);
|
||||||
|
|
||||||
|
element.click();
|
||||||
|
|
||||||
|
document.body.removeChild(element);
|
||||||
|
}
|
|
@ -72,9 +72,17 @@ function check_gallery(id_gallery){
|
||||||
let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
|
let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
|
||||||
let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
|
let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
|
||||||
if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
|
if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
|
||||||
//automatically re-open previously selected index (if exists)
|
// automatically re-open previously selected index (if exists)
|
||||||
|
activeElement = gradioApp().activeElement;
|
||||||
|
|
||||||
galleryButtons[prevSelectedIndex].click();
|
galleryButtons[prevSelectedIndex].click();
|
||||||
showGalleryImage();
|
showGalleryImage();
|
||||||
|
|
||||||
|
if(activeElement){
|
||||||
|
// i fought this for about an hour; i don't know why the focus is lost or why this helps recover it
|
||||||
|
// if somenoe has a better solution please by all means
|
||||||
|
setTimeout(function() { activeElement.focus() }, 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false })
|
galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false })
|
||||||
|
|
|
@ -1,5 +1,12 @@
|
||||||
// various functions for interation with ui.py not large enough to warrant putting them in separate files
|
// various functions for interation with ui.py not large enough to warrant putting them in separate files
|
||||||
|
|
||||||
|
function set_theme(theme){
|
||||||
|
gradioURL = window.location.href
|
||||||
|
if (!gradioURL.includes('?__theme=')) {
|
||||||
|
window.location.replace(gradioURL + '?__theme=' + theme);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function selected_gallery_index(){
|
function selected_gallery_index(){
|
||||||
var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem .gallery-item')
|
var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem .gallery-item')
|
||||||
var button = gradioApp().querySelector('[style="display: block;"].tabitem .gallery-item.\\!ring-2')
|
var button = gradioApp().querySelector('[style="display: block;"].tabitem .gallery-item.\\!ring-2')
|
||||||
|
|
58
launch.py
58
launch.py
|
@ -87,6 +87,23 @@ def git_clone(url, dir, name, commithash=None):
|
||||||
run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
||||||
|
|
||||||
|
|
||||||
|
def version_check(commit):
|
||||||
|
try:
|
||||||
|
import requests
|
||||||
|
commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
|
||||||
|
if commit != "<none>" and commits['commit']['sha'] != commit:
|
||||||
|
print("--------------------------------------------------------")
|
||||||
|
print("| You are not up to date with the most recent release. |")
|
||||||
|
print("| Consider running `git pull` to update. |")
|
||||||
|
print("--------------------------------------------------------")
|
||||||
|
elif commits['commit']['sha'] == commit:
|
||||||
|
print("You are up to date with the most recent release.")
|
||||||
|
else:
|
||||||
|
print("Not a git clone, can't perform version check.")
|
||||||
|
except Exception as e:
|
||||||
|
print("versipm check failed",e)
|
||||||
|
|
||||||
|
|
||||||
def prepare_enviroment():
|
def prepare_enviroment():
|
||||||
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
|
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
|
||||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||||
|
@ -94,6 +111,15 @@ def prepare_enviroment():
|
||||||
|
|
||||||
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
||||||
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
|
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
|
||||||
|
deepdanbooru_package = os.environ.get('DEEPDANBOORU_PACKAGE', "git+https://github.com/KichangKim/DeepDanbooru.git@edf73df4cdaeea2cf00e9ac08bd8a9026b7a7b26")
|
||||||
|
|
||||||
|
xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl')
|
||||||
|
|
||||||
|
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/CompVis/stable-diffusion.git")
|
||||||
|
taming_transformers_repo = os.environ.get('TAMING_REANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
|
||||||
|
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
||||||
|
codeformer_repo = os.environ.get('CODEFORMET_REPO', 'https://github.com/sczhou/CodeFormer.git')
|
||||||
|
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
||||||
|
|
||||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
|
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
|
||||||
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
|
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
|
||||||
|
@ -101,13 +127,14 @@ def prepare_enviroment():
|
||||||
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
||||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||||
|
|
||||||
args = shlex.split(commandline_args)
|
sys.argv += shlex.split(commandline_args)
|
||||||
|
|
||||||
args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test')
|
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
|
||||||
args, reinstall_xformers = extract_arg(args, '--reinstall-xformers')
|
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
|
||||||
xformers = '--xformers' in args
|
sys.argv, update_check = extract_arg(sys.argv, '--update-check')
|
||||||
deepdanbooru = '--deepdanbooru' in args
|
xformers = '--xformers' in sys.argv
|
||||||
ngrok = '--ngrok' in args
|
deepdanbooru = '--deepdanbooru' in sys.argv
|
||||||
|
ngrok = '--ngrok' in sys.argv
|
||||||
|
|
||||||
try:
|
try:
|
||||||
commit = run(f"{git} rev-parse HEAD").strip()
|
commit = run(f"{git} rev-parse HEAD").strip()
|
||||||
|
@ -131,32 +158,33 @@ def prepare_enviroment():
|
||||||
|
|
||||||
if (not is_installed("xformers") or reinstall_xformers) and xformers and platform.python_version().startswith("3.10"):
|
if (not is_installed("xformers") or reinstall_xformers) and xformers and platform.python_version().startswith("3.10"):
|
||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
run_pip("install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers")
|
run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers")
|
||||||
elif platform.system() == "Linux":
|
elif platform.system() == "Linux":
|
||||||
run_pip("install xformers", "xformers")
|
run_pip("install xformers", "xformers")
|
||||||
|
|
||||||
if not is_installed("deepdanbooru") and deepdanbooru:
|
if not is_installed("deepdanbooru") and deepdanbooru:
|
||||||
run_pip("install git+https://github.com/KichangKim/DeepDanbooru.git@edf73df4cdaeea2cf00e9ac08bd8a9026b7a7b26#egg=deepdanbooru[tensorflow] tensorflow==2.10.0 tensorflow-io==0.27.0", "deepdanbooru")
|
run_pip(f"install {deepdanbooru_package}#egg=deepdanbooru[tensorflow] tensorflow==2.10.0 tensorflow-io==0.27.0", "deepdanbooru")
|
||||||
|
|
||||||
if not is_installed("pyngrok") and ngrok:
|
if not is_installed("pyngrok") and ngrok:
|
||||||
run_pip("install pyngrok", "ngrok")
|
run_pip("install pyngrok", "ngrok")
|
||||||
|
|
||||||
os.makedirs(dir_repos, exist_ok=True)
|
os.makedirs(dir_repos, exist_ok=True)
|
||||||
|
|
||||||
git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
|
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||||
git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
|
git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
|
||||||
git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
||||||
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
||||||
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||||
|
|
||||||
if not is_installed("lpips"):
|
if not is_installed("lpips"):
|
||||||
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
|
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
|
||||||
|
|
||||||
run_pip(f"install -r {requirements_file}", "requirements for Web UI")
|
run_pip(f"install -r {requirements_file}", "requirements for Web UI")
|
||||||
|
|
||||||
sys.argv += args
|
if update_check:
|
||||||
|
version_check(commit)
|
||||||
|
|
||||||
if "--exit" in args:
|
if "--exit" in sys.argv:
|
||||||
print("Exiting because of --exit argument")
|
print("Exiting because of --exit argument")
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
|
|
0
localizations/Put localization files here.txt
Normal file
0
localizations/Put localization files here.txt
Normal file
68
modules/api/api.py
Normal file
68
modules/api/api.py
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
from modules.api.processing import StableDiffusionProcessingAPI
|
||||||
|
from modules.processing import StableDiffusionProcessingTxt2Img, process_images
|
||||||
|
from modules.sd_samplers import all_samplers
|
||||||
|
from modules.extras import run_pnginfo
|
||||||
|
import modules.shared as shared
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import Body, APIRouter, HTTPException
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from pydantic import BaseModel, Field, Json
|
||||||
|
import json
|
||||||
|
import io
|
||||||
|
import base64
|
||||||
|
|
||||||
|
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
|
||||||
|
|
||||||
|
class TextToImageResponse(BaseModel):
|
||||||
|
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||||
|
parameters: Json
|
||||||
|
info: Json
|
||||||
|
|
||||||
|
|
||||||
|
class Api:
|
||||||
|
def __init__(self, app, queue_lock):
|
||||||
|
self.router = APIRouter()
|
||||||
|
self.app = app
|
||||||
|
self.queue_lock = queue_lock
|
||||||
|
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
|
||||||
|
|
||||||
|
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
|
||||||
|
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
|
||||||
|
|
||||||
|
if sampler_index is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Sampler not found")
|
||||||
|
|
||||||
|
populate = txt2imgreq.copy(update={ # Override __init__ params
|
||||||
|
"sd_model": shared.sd_model,
|
||||||
|
"sampler_index": sampler_index[0],
|
||||||
|
"do_not_save_samples": True,
|
||||||
|
"do_not_save_grid": True
|
||||||
|
}
|
||||||
|
)
|
||||||
|
p = StableDiffusionProcessingTxt2Img(**vars(populate))
|
||||||
|
# Override object param
|
||||||
|
with self.queue_lock:
|
||||||
|
processed = process_images(p)
|
||||||
|
|
||||||
|
b64images = []
|
||||||
|
for i in processed.images:
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
i.save(buffer, format="png")
|
||||||
|
b64images.append(base64.b64encode(buffer.getvalue()))
|
||||||
|
|
||||||
|
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def img2imgapi(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def extrasapi(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def pnginfoapi(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def launch(self, server_name, port):
|
||||||
|
self.app.include_router(self.router)
|
||||||
|
uvicorn.run(self.app, host=server_name, port=port)
|
99
modules/api/processing.py
Normal file
99
modules/api/processing.py
Normal file
|
@ -0,0 +1,99 @@
|
||||||
|
from inflection import underscore
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
from pydantic import BaseModel, Field, create_model
|
||||||
|
from modules.processing import StableDiffusionProcessingTxt2Img
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
|
||||||
|
API_NOT_ALLOWED = [
|
||||||
|
"self",
|
||||||
|
"kwargs",
|
||||||
|
"sd_model",
|
||||||
|
"outpath_samples",
|
||||||
|
"outpath_grids",
|
||||||
|
"sampler_index",
|
||||||
|
"do_not_save_samples",
|
||||||
|
"do_not_save_grid",
|
||||||
|
"extra_generation_params",
|
||||||
|
"overlay_images",
|
||||||
|
"do_not_reload_embeddings",
|
||||||
|
"seed_enable_extras",
|
||||||
|
"prompt_for_display",
|
||||||
|
"sampler_noise_scheduler_override",
|
||||||
|
"ddim_discretize"
|
||||||
|
]
|
||||||
|
|
||||||
|
class ModelDef(BaseModel):
|
||||||
|
"""Assistance Class for Pydantic Dynamic Model Generation"""
|
||||||
|
|
||||||
|
field: str
|
||||||
|
field_alias: str
|
||||||
|
field_type: Any
|
||||||
|
field_value: Any
|
||||||
|
|
||||||
|
|
||||||
|
class PydanticModelGenerator:
|
||||||
|
"""
|
||||||
|
Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
|
||||||
|
source_data is a snapshot of the default values produced by the class
|
||||||
|
params are the names of the actual keys required by __init__
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str = None,
|
||||||
|
class_instance = None,
|
||||||
|
additional_fields = None,
|
||||||
|
):
|
||||||
|
def field_type_generator(k, v):
|
||||||
|
# field_type = str if not overrides.get(k) else overrides[k]["type"]
|
||||||
|
# print(k, v.annotation, v.default)
|
||||||
|
field_type = v.annotation
|
||||||
|
|
||||||
|
return Optional[field_type]
|
||||||
|
|
||||||
|
def merge_class_params(class_):
|
||||||
|
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
|
||||||
|
parameters = {}
|
||||||
|
for classes in all_classes:
|
||||||
|
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
|
||||||
|
return parameters
|
||||||
|
|
||||||
|
|
||||||
|
self._model_name = model_name
|
||||||
|
self._class_data = merge_class_params(class_instance)
|
||||||
|
self._model_def = [
|
||||||
|
ModelDef(
|
||||||
|
field=underscore(k),
|
||||||
|
field_alias=k,
|
||||||
|
field_type=field_type_generator(k, v),
|
||||||
|
field_value=v.default
|
||||||
|
)
|
||||||
|
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
||||||
|
]
|
||||||
|
|
||||||
|
for fields in additional_fields:
|
||||||
|
self._model_def.append(ModelDef(
|
||||||
|
field=underscore(fields["key"]),
|
||||||
|
field_alias=fields["key"],
|
||||||
|
field_type=fields["type"],
|
||||||
|
field_value=fields["default"]))
|
||||||
|
|
||||||
|
def generate_model(self):
|
||||||
|
"""
|
||||||
|
Creates a pydantic BaseModel
|
||||||
|
from the json and overrides provided at initialization
|
||||||
|
"""
|
||||||
|
fields = {
|
||||||
|
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
|
||||||
|
}
|
||||||
|
DynamicModel = create_model(self._model_name, **fields)
|
||||||
|
DynamicModel.__config__.allow_population_by_field_name = True
|
||||||
|
DynamicModel.__config__.allow_mutation = True
|
||||||
|
return DynamicModel
|
||||||
|
|
||||||
|
StableDiffusionProcessingAPI = PydanticModelGenerator(
|
||||||
|
"StableDiffusionProcessingTxt2Img",
|
||||||
|
StableDiffusionProcessingTxt2Img,
|
||||||
|
[{"key": "sampler_index", "type": str, "default": "Euler"}]
|
||||||
|
).generate_model()
|
|
@ -157,8 +157,7 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_o
|
||||||
# sort by reverse by likelihood and normal for alpha, and format tag text as requested
|
# sort by reverse by likelihood and normal for alpha, and format tag text as requested
|
||||||
unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
|
unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
|
||||||
for weight, tag in unsorted_tags_in_theshold:
|
for weight, tag in unsorted_tags_in_theshold:
|
||||||
# note: tag_outformat will still have a colon if include_ranks is True
|
tag_outformat = tag
|
||||||
tag_outformat = tag.replace(':', ' ')
|
|
||||||
if use_spaces:
|
if use_spaces:
|
||||||
tag_outformat = tag_outformat.replace('_', ' ')
|
tag_outformat = tag_outformat.replace('_', ' ')
|
||||||
if use_escape:
|
if use_escape:
|
||||||
|
|
|
@ -20,12 +20,13 @@ import gradio as gr
|
||||||
cached_images = {}
|
cached_images = {}
|
||||||
|
|
||||||
|
|
||||||
def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
|
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
imageArr = []
|
imageArr = []
|
||||||
# Also keep track of original file names
|
# Also keep track of original file names
|
||||||
imageNameArr = []
|
imageNameArr = []
|
||||||
|
outputs = []
|
||||||
|
|
||||||
if extras_mode == 1:
|
if extras_mode == 1:
|
||||||
#convert file to pillow image
|
#convert file to pillow image
|
||||||
|
@ -33,13 +34,26 @@ def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility,
|
||||||
image = Image.open(img)
|
image = Image.open(img)
|
||||||
imageArr.append(image)
|
imageArr.append(image)
|
||||||
imageNameArr.append(os.path.splitext(img.orig_name)[0])
|
imageNameArr.append(os.path.splitext(img.orig_name)[0])
|
||||||
|
elif extras_mode == 2:
|
||||||
|
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
||||||
|
|
||||||
|
if input_dir == '':
|
||||||
|
return outputs, "Please select an input directory.", ''
|
||||||
|
image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
|
||||||
|
for img in image_list:
|
||||||
|
image = Image.open(img)
|
||||||
|
imageArr.append(image)
|
||||||
|
imageNameArr.append(img)
|
||||||
else:
|
else:
|
||||||
imageArr.append(image)
|
imageArr.append(image)
|
||||||
imageNameArr.append(None)
|
imageNameArr.append(None)
|
||||||
|
|
||||||
|
if extras_mode == 2 and output_dir != '':
|
||||||
|
outpath = output_dir
|
||||||
|
else:
|
||||||
outpath = opts.outdir_samples or opts.outdir_extras_samples
|
outpath = opts.outdir_samples or opts.outdir_extras_samples
|
||||||
|
|
||||||
outputs = []
|
|
||||||
for image, image_name in zip(imageArr, imageNameArr):
|
for image, image_name in zip(imageArr, imageNameArr):
|
||||||
if image is None:
|
if image is None:
|
||||||
return outputs, "Please select an input image.", ''
|
return outputs, "Please select an input image.", ''
|
||||||
|
@ -77,7 +91,8 @@ def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility,
|
||||||
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
|
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
|
||||||
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
|
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
|
||||||
pixels = tuple(np.array(small).flatten().tolist())
|
pixels = tuple(np.array(small).flatten().tolist())
|
||||||
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
|
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight,
|
||||||
|
resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels
|
||||||
|
|
||||||
c = cached_images.get(key)
|
c = cached_images.get(key)
|
||||||
if c is None:
|
if c is None:
|
||||||
|
@ -112,6 +127,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility,
|
||||||
image.info = existing_pnginfo
|
image.info = existing_pnginfo
|
||||||
image.info["extras"] = info
|
image.info["extras"] = info
|
||||||
|
|
||||||
|
if extras_mode != 2 or show_extras_results :
|
||||||
outputs.append(image)
|
outputs.append(image)
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
@ -160,11 +176,14 @@ def run_pnginfo(image):
|
||||||
|
|
||||||
|
|
||||||
def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
|
def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
|
||||||
def weighted_sum(theta0, theta1, theta2, alpha):
|
def weighted_sum(theta0, theta1, alpha):
|
||||||
return ((1 - alpha) * theta0) + (alpha * theta1)
|
return ((1 - alpha) * theta0) + (alpha * theta1)
|
||||||
|
|
||||||
def add_difference(theta0, theta1, theta2, alpha):
|
def get_difference(theta1, theta2):
|
||||||
return theta0 + (theta1 - theta2) * alpha
|
return theta1 - theta2
|
||||||
|
|
||||||
|
def add_difference(theta0, theta1_2_diff, alpha):
|
||||||
|
return theta0 + (alpha * theta1_2_diff)
|
||||||
|
|
||||||
primary_model_info = sd_models.checkpoints_list[primary_model_name]
|
primary_model_info = sd_models.checkpoints_list[primary_model_name]
|
||||||
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
|
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
|
||||||
|
@ -183,23 +202,31 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
|
||||||
teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
|
teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
|
||||||
theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
|
theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
|
||||||
else:
|
else:
|
||||||
|
teritary_model = None
|
||||||
theta_2 = None
|
theta_2 = None
|
||||||
|
|
||||||
theta_funcs = {
|
theta_funcs = {
|
||||||
"Weighted sum": weighted_sum,
|
"Weighted sum": (None, weighted_sum),
|
||||||
"Add difference": add_difference,
|
"Add difference": (get_difference, add_difference),
|
||||||
}
|
}
|
||||||
theta_func = theta_funcs[interp_method]
|
theta_func1, theta_func2 = theta_funcs[interp_method]
|
||||||
|
|
||||||
print(f"Merging...")
|
print(f"Merging...")
|
||||||
|
|
||||||
|
if theta_func1:
|
||||||
|
for key in tqdm.tqdm(theta_1.keys()):
|
||||||
|
if 'model' in key:
|
||||||
|
if key in theta_2:
|
||||||
|
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
|
||||||
|
theta_1[key] = theta_func1(theta_1[key], t2)
|
||||||
|
else:
|
||||||
|
theta_1[key] = torch.zeros_like(theta_1[key])
|
||||||
|
del theta_2, teritary_model
|
||||||
|
|
||||||
for key in tqdm.tqdm(theta_0.keys()):
|
for key in tqdm.tqdm(theta_0.keys()):
|
||||||
if 'model' in key and key in theta_1:
|
if 'model' in key and key in theta_1:
|
||||||
t2 = (theta_2 or {}).get(key)
|
|
||||||
if t2 is None:
|
|
||||||
t2 = torch.zeros_like(theta_0[key])
|
|
||||||
|
|
||||||
theta_0[key] = theta_func(theta_0[key], theta_1[key], t2, multiplier)
|
theta_0[key] = theta_func2(theta_0[key], theta_1[key], multiplier)
|
||||||
|
|
||||||
if save_as_half:
|
if save_as_half:
|
||||||
theta_0[key] = theta_0[key].half()
|
theta_0[key] = theta_0[key].half()
|
||||||
|
|
|
@ -196,7 +196,7 @@ def stack_conds(conds):
|
||||||
|
|
||||||
return torch.stack(conds)
|
return torch.stack(conds)
|
||||||
|
|
||||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
assert hypernetwork_name, 'hypernetwork not selected'
|
assert hypernetwork_name, 'hypernetwork not selected'
|
||||||
|
|
||||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||||
|
@ -225,7 +225,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
|
|
||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
|
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
|
||||||
|
|
||||||
if unload:
|
if unload:
|
||||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
import sys
|
||||||
|
|
||||||
def traverse_all_files(output_dir, image_list, curr_dir=None):
|
def traverse_all_files(output_dir, image_list, curr_dir=None):
|
||||||
curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
|
curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
|
||||||
|
@ -24,10 +24,14 @@ def traverse_all_files(output_dir, image_list, curr_dir=None):
|
||||||
|
|
||||||
def get_recent_images(dir_name, page_index, step, image_index, tabname):
|
def get_recent_images(dir_name, page_index, step, image_index, tabname):
|
||||||
page_index = int(page_index)
|
page_index = int(page_index)
|
||||||
f_list = os.listdir(dir_name)
|
|
||||||
image_list = []
|
image_list = []
|
||||||
|
if not os.path.exists(dir_name):
|
||||||
|
pass
|
||||||
|
elif os.path.isdir(dir_name):
|
||||||
image_list = traverse_all_files(dir_name, image_list)
|
image_list = traverse_all_files(dir_name, image_list)
|
||||||
image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file)))
|
image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file)))
|
||||||
|
else:
|
||||||
|
print(f'ERROR: "{dir_name}" is not a directory. Check the path in the settings.', file=sys.stderr)
|
||||||
num = 48 if tabname != "extras" else 12
|
num = 48 if tabname != "extras" else 12
|
||||||
max_page_index = len(image_list) // num + 1
|
max_page_index = len(image_list) // num + 1
|
||||||
page_index = max_page_index if page_index == -1 else page_index + step
|
page_index = max_page_index if page_index == -1 else page_index + step
|
||||||
|
@ -105,10 +109,8 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
|
||||||
dir_name = opts.outdir_img2img_samples
|
dir_name = opts.outdir_img2img_samples
|
||||||
elif tabname == "extras":
|
elif tabname == "extras":
|
||||||
dir_name = opts.outdir_extras_samples
|
dir_name = opts.outdir_extras_samples
|
||||||
d = dir_name.split("/")
|
else:
|
||||||
dir_name = "/" if dir_name.startswith("/") else d[0]
|
return
|
||||||
for p in d[1:]:
|
|
||||||
dir_name = os.path.join(dir_name, p)
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page")
|
renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page")
|
||||||
first_page = gr.Button('First Page')
|
first_page = gr.Button('First Page')
|
||||||
|
|
|
@ -123,7 +123,7 @@ class InterrogateModels:
|
||||||
|
|
||||||
return caption[0]
|
return caption[0]
|
||||||
|
|
||||||
def interrogate(self, pil_image, include_ranks=False):
|
def interrogate(self, pil_image):
|
||||||
res = None
|
res = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -156,10 +156,10 @@ class InterrogateModels:
|
||||||
for name, topn, items in self.categories:
|
for name, topn, items in self.categories:
|
||||||
matches = self.rank(image_features, items, top_count=topn)
|
matches = self.rank(image_features, items, top_count=topn)
|
||||||
for match, score in matches:
|
for match, score in matches:
|
||||||
if include_ranks:
|
if shared.opts.interrogate_return_ranks:
|
||||||
res += ", " + match
|
res += f", ({match}:{score/100:.3f})"
|
||||||
else:
|
else:
|
||||||
res += f", ({match}:{score})"
|
res += ", " + match
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error interrogating", file=sys.stderr)
|
print(f"Error interrogating", file=sys.stderr)
|
||||||
|
|
31
modules/localization.py
Normal file
31
modules/localization.py
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
localizations = {}
|
||||||
|
|
||||||
|
|
||||||
|
def list_localizations(dirname):
|
||||||
|
localizations.clear()
|
||||||
|
|
||||||
|
for file in os.listdir(dirname):
|
||||||
|
fn, ext = os.path.splitext(file)
|
||||||
|
if ext.lower() != ".json":
|
||||||
|
continue
|
||||||
|
|
||||||
|
localizations[fn] = os.path.join(dirname, file)
|
||||||
|
|
||||||
|
|
||||||
|
def localization_js(current_localization_name):
|
||||||
|
fn = localizations.get(current_localization_name, None)
|
||||||
|
data = {}
|
||||||
|
if fn is not None:
|
||||||
|
try:
|
||||||
|
with open(fn, "r", encoding="utf8") as file:
|
||||||
|
data = json.load(file)
|
||||||
|
except Exception:
|
||||||
|
print(f"Error loading localization from {fn}:", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
return f"var localization = {json.dumps(data)}\n"
|
|
@ -1,12 +1,14 @@
|
||||||
from pyngrok import ngrok, conf, exception
|
from pyngrok import ngrok, conf, exception
|
||||||
|
|
||||||
|
|
||||||
def connect(token, port):
|
def connect(token, port, region):
|
||||||
if token == None:
|
if token == None:
|
||||||
token = 'None'
|
token = 'None'
|
||||||
conf.get_default().auth_token = token
|
config = conf.PyngrokConfig(
|
||||||
|
auth_token=token, region=region
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
public_url = ngrok.connect(port).public_url
|
public_url = ngrok.connect(port, pyngrok_config=config).public_url
|
||||||
except exception.PyngrokNgrokError:
|
except exception.PyngrokNgrokError:
|
||||||
print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
|
print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
|
||||||
f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
|
f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
|
||||||
|
|
|
@ -9,6 +9,7 @@ from PIL import Image, ImageFilter, ImageOps
|
||||||
import random
|
import random
|
||||||
import cv2
|
import cv2
|
||||||
from skimage import exposure
|
from skimage import exposure
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram
|
from modules import devices, prompt_parser, masking, sd_samplers, lowvram
|
||||||
|
@ -51,9 +52,15 @@ def get_correct_sampler(p):
|
||||||
return sd_samplers.samplers
|
return sd_samplers.samplers
|
||||||
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
|
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
|
||||||
return sd_samplers.samplers_for_img2img
|
return sd_samplers.samplers_for_img2img
|
||||||
|
elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
|
||||||
|
return sd_samplers.samplers
|
||||||
|
|
||||||
class StableDiffusionProcessing:
|
class StableDiffusionProcessing():
|
||||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
|
"""
|
||||||
|
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps:int =50, cfg_scale:float=7.0, width:int=512, height:int=512, restore_faces:bool=False, tiling:bool=False, do_not_save_samples:bool=False, do_not_save_grid:bool=False, extra_generation_params: Dict[Any,Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float =None, do_not_reload_embeddings: bool=False, denoising_strength: float = 0, ddim_discretize: str = "uniform", s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0):
|
||||||
self.sd_model = sd_model
|
self.sd_model = sd_model
|
||||||
self.outpath_samples: str = outpath_samples
|
self.outpath_samples: str = outpath_samples
|
||||||
self.outpath_grids: str = outpath_grids
|
self.outpath_grids: str = outpath_grids
|
||||||
|
@ -80,15 +87,16 @@ class StableDiffusionProcessing:
|
||||||
self.extra_generation_params: dict = extra_generation_params or {}
|
self.extra_generation_params: dict = extra_generation_params or {}
|
||||||
self.overlay_images = overlay_images
|
self.overlay_images = overlay_images
|
||||||
self.eta = eta
|
self.eta = eta
|
||||||
|
self.do_not_reload_embeddings = do_not_reload_embeddings
|
||||||
self.paste_to = None
|
self.paste_to = None
|
||||||
self.color_corrections = None
|
self.color_corrections = None
|
||||||
self.denoising_strength: float = 0
|
self.denoising_strength: float = 0
|
||||||
self.sampler_noise_scheduler_override = None
|
self.sampler_noise_scheduler_override = None
|
||||||
self.ddim_discretize = opts.ddim_discretize
|
self.ddim_discretize = opts.ddim_discretize
|
||||||
self.s_churn = opts.s_churn
|
self.s_churn = s_churn or opts.s_churn
|
||||||
self.s_tmin = opts.s_tmin
|
self.s_tmin = s_tmin or opts.s_tmin
|
||||||
self.s_tmax = float('inf') # not representable as a standard ui option
|
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
|
||||||
self.s_noise = opts.s_noise
|
self.s_noise = s_noise or opts.s_noise
|
||||||
|
|
||||||
if not seed_enable_extras:
|
if not seed_enable_extras:
|
||||||
self.subseed = -1
|
self.subseed = -1
|
||||||
|
@ -96,6 +104,7 @@ class StableDiffusionProcessing:
|
||||||
self.seed_resize_from_h = 0
|
self.seed_resize_from_h = 0
|
||||||
self.seed_resize_from_w = 0
|
self.seed_resize_from_w = 0
|
||||||
|
|
||||||
|
|
||||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -333,12 +342,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
seed = get_fixed_seed(p.seed)
|
seed = get_fixed_seed(p.seed)
|
||||||
subseed = get_fixed_seed(p.subseed)
|
subseed = get_fixed_seed(p.subseed)
|
||||||
|
|
||||||
if p.outpath_samples is not None:
|
|
||||||
os.makedirs(p.outpath_samples, exist_ok=True)
|
|
||||||
|
|
||||||
if p.outpath_grids is not None:
|
|
||||||
os.makedirs(p.outpath_grids, exist_ok=True)
|
|
||||||
|
|
||||||
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
||||||
modules.sd_hijack.model_hijack.clear_comments()
|
modules.sd_hijack.model_hijack.clear_comments()
|
||||||
|
|
||||||
|
@ -364,7 +367,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
def infotext(iteration=0, position_in_batch=0):
|
def infotext(iteration=0, position_in_batch=0):
|
||||||
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
|
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
|
||||||
|
|
||||||
if os.path.exists(cmd_opts.embeddings_dir):
|
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
||||||
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||||
|
|
||||||
infotexts = []
|
infotexts = []
|
||||||
|
@ -407,12 +410,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
|
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
|
||||||
|
|
||||||
if state.interrupted or state.skipped:
|
|
||||||
|
|
||||||
# if we are interrupted, sample returns just noise
|
|
||||||
# use the image collected previously in sampler loop
|
|
||||||
samples_ddim = shared.state.current_latent
|
|
||||||
|
|
||||||
samples_ddim = samples_ddim.to(devices.dtype_vae)
|
samples_ddim = samples_ddim.to(devices.dtype_vae)
|
||||||
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
|
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
|
||||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
@ -502,7 +499,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
sampler = None
|
sampler = None
|
||||||
|
|
||||||
def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=0, firstphase_height=0, **kwargs):
|
def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.enable_hr = enable_hr
|
self.enable_hr = enable_hr
|
||||||
self.denoising_strength = denoising_strength
|
self.denoising_strength = denoising_strength
|
||||||
|
|
|
@ -58,6 +58,9 @@ def load_scripts(basedir):
|
||||||
for filename in sorted(os.listdir(basedir)):
|
for filename in sorted(os.listdir(basedir)):
|
||||||
path = os.path.join(basedir, filename)
|
path = os.path.join(basedir, filename)
|
||||||
|
|
||||||
|
if os.path.splitext(path)[1].lower() != '.py':
|
||||||
|
continue
|
||||||
|
|
||||||
if not os.path.isfile(path):
|
if not os.path.isfile(path):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -93,6 +96,7 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
||||||
class ScriptRunner:
|
class ScriptRunner:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.scripts = []
|
self.scripts = []
|
||||||
|
self.titles = []
|
||||||
|
|
||||||
def setup_ui(self, is_img2img):
|
def setup_ui(self, is_img2img):
|
||||||
for script_class, path in scripts_data:
|
for script_class, path in scripts_data:
|
||||||
|
@ -104,9 +108,10 @@ class ScriptRunner:
|
||||||
|
|
||||||
self.scripts.append(script)
|
self.scripts.append(script)
|
||||||
|
|
||||||
titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]
|
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]
|
||||||
|
|
||||||
dropdown = gr.Dropdown(label="Script", choices=["None"] + titles, value="None", type="index")
|
dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
|
||||||
|
dropdown.save_to_config = True
|
||||||
inputs = [dropdown]
|
inputs = [dropdown]
|
||||||
|
|
||||||
for script in self.scripts:
|
for script in self.scripts:
|
||||||
|
@ -136,6 +141,15 @@ class ScriptRunner:
|
||||||
|
|
||||||
return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]
|
return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]
|
||||||
|
|
||||||
|
def init_field(title):
|
||||||
|
if title == 'None':
|
||||||
|
return
|
||||||
|
script_index = self.titles.index(title)
|
||||||
|
script = self.scripts[script_index]
|
||||||
|
for i in range(script.args_from, script.args_to):
|
||||||
|
inputs[i].visible = True
|
||||||
|
|
||||||
|
dropdown.init_field = init_field
|
||||||
dropdown.change(
|
dropdown.change(
|
||||||
fn=select_script,
|
fn=select_script,
|
||||||
inputs=[dropdown],
|
inputs=[dropdown],
|
||||||
|
|
|
@ -181,7 +181,7 @@ def einsum_op_cuda(q, k, v):
|
||||||
mem_free_torch = mem_reserved - mem_active
|
mem_free_torch = mem_reserved - mem_active
|
||||||
mem_free_total = mem_free_cuda + mem_free_torch
|
mem_free_total = mem_free_cuda + mem_free_torch
|
||||||
# Divide factor of safety as there's copying and fragmentation
|
# Divide factor of safety as there's copying and fragmentation
|
||||||
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||||
|
|
||||||
def einsum_op(q, k, v):
|
def einsum_op(q, k, v):
|
||||||
if q.device.type == 'cuda':
|
if q.device.type == 'cuda':
|
||||||
|
@ -296,10 +296,16 @@ def xformers_attnblock_forward(self, x):
|
||||||
try:
|
try:
|
||||||
h_ = x
|
h_ = x
|
||||||
h_ = self.norm(h_)
|
h_ = self.norm(h_)
|
||||||
q1 = self.q(h_).contiguous()
|
q = self.q(h_)
|
||||||
k1 = self.k(h_).contiguous()
|
k = self.k(h_)
|
||||||
v = self.v(h_).contiguous()
|
v = self.v(h_)
|
||||||
out = xformers.ops.memory_efficient_attention(q1, k1, v)
|
b, c, h, w = q.shape
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
|
||||||
|
q = q.contiguous()
|
||||||
|
k = k.contiguous()
|
||||||
|
v = v.contiguous()
|
||||||
|
out = xformers.ops.memory_efficient_attention(q, k, v)
|
||||||
|
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
|
||||||
out = self.proj_out(out)
|
out = self.proj_out(out)
|
||||||
return x + out
|
return x + out
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
|
|
|
@ -122,11 +122,33 @@ def select_checkpoint():
|
||||||
return checkpoint_info
|
return checkpoint_info
|
||||||
|
|
||||||
|
|
||||||
|
chckpoint_dict_replacements = {
|
||||||
|
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
|
||||||
|
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
|
||||||
|
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def transform_checkpoint_dict_key(k):
|
||||||
|
for text, replacement in chckpoint_dict_replacements.items():
|
||||||
|
if k.startswith(text):
|
||||||
|
k = replacement + k[len(text):]
|
||||||
|
|
||||||
|
return k
|
||||||
|
|
||||||
|
|
||||||
def get_state_dict_from_checkpoint(pl_sd):
|
def get_state_dict_from_checkpoint(pl_sd):
|
||||||
if "state_dict" in pl_sd:
|
if "state_dict" in pl_sd:
|
||||||
return pl_sd["state_dict"]
|
pl_sd = pl_sd["state_dict"]
|
||||||
|
|
||||||
return pl_sd
|
sd = {}
|
||||||
|
for k, v in pl_sd.items():
|
||||||
|
new_key = transform_checkpoint_dict_key(k)
|
||||||
|
|
||||||
|
if new_key is not None:
|
||||||
|
sd[new_key] = v
|
||||||
|
|
||||||
|
return sd
|
||||||
|
|
||||||
|
|
||||||
def load_model_weights(model, checkpoint_info):
|
def load_model_weights(model, checkpoint_info):
|
||||||
|
@ -141,7 +163,7 @@ def load_model_weights(model, checkpoint_info):
|
||||||
print(f"Global Step: {pl_sd['global_step']}")
|
print(f"Global Step: {pl_sd['global_step']}")
|
||||||
|
|
||||||
sd = get_state_dict_from_checkpoint(pl_sd)
|
sd = get_state_dict_from_checkpoint(pl_sd)
|
||||||
model.load_state_dict(sd, strict=False)
|
missing, extra = model.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
if shared.cmd_opts.opt_channelslast:
|
if shared.cmd_opts.opt_channelslast:
|
||||||
model.to(memory_format=torch.channels_last)
|
model.to(memory_format=torch.channels_last)
|
||||||
|
|
|
@ -98,25 +98,8 @@ def store_latent(decoded):
|
||||||
shared.state.current_image = sample_to_image(decoded)
|
shared.state.current_image = sample_to_image(decoded)
|
||||||
|
|
||||||
|
|
||||||
|
class InterruptedException(BaseException):
|
||||||
def extended_tdqm(sequence, *args, desc=None, **kwargs):
|
pass
|
||||||
state.sampling_steps = len(sequence)
|
|
||||||
state.sampling_step = 0
|
|
||||||
|
|
||||||
seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
|
|
||||||
|
|
||||||
for x in seq:
|
|
||||||
if state.interrupted or state.skipped:
|
|
||||||
break
|
|
||||||
|
|
||||||
yield x
|
|
||||||
|
|
||||||
state.sampling_step += 1
|
|
||||||
shared.total_tqdm.update()
|
|
||||||
|
|
||||||
|
|
||||||
ldm.models.diffusion.ddim.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
|
|
||||||
ldm.models.diffusion.plms.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class VanillaStableDiffusionSampler:
|
class VanillaStableDiffusionSampler:
|
||||||
|
@ -128,14 +111,32 @@ class VanillaStableDiffusionSampler:
|
||||||
self.init_latent = None
|
self.init_latent = None
|
||||||
self.sampler_noises = None
|
self.sampler_noises = None
|
||||||
self.step = 0
|
self.step = 0
|
||||||
|
self.stop_at = None
|
||||||
self.eta = None
|
self.eta = None
|
||||||
self.default_eta = 0.0
|
self.default_eta = 0.0
|
||||||
self.config = None
|
self.config = None
|
||||||
|
self.last_latent = None
|
||||||
|
|
||||||
def number_of_needed_noises(self, p):
|
def number_of_needed_noises(self, p):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
def launch_sampling(self, steps, func):
|
||||||
|
state.sampling_steps = steps
|
||||||
|
state.sampling_step = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
return func()
|
||||||
|
except InterruptedException:
|
||||||
|
return self.last_latent
|
||||||
|
|
||||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
||||||
|
if state.interrupted or state.skipped:
|
||||||
|
raise InterruptedException
|
||||||
|
|
||||||
|
if self.stop_at is not None and self.step > self.stop_at:
|
||||||
|
raise InterruptedException
|
||||||
|
|
||||||
|
|
||||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||||
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
||||||
|
|
||||||
|
@ -159,11 +160,16 @@ class VanillaStableDiffusionSampler:
|
||||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
||||||
|
|
||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
store_latent(self.init_latent * self.mask + self.nmask * res[1])
|
self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
|
||||||
else:
|
else:
|
||||||
store_latent(res[1])
|
self.last_latent = res[1]
|
||||||
|
|
||||||
|
store_latent(self.last_latent)
|
||||||
|
|
||||||
self.step += 1
|
self.step += 1
|
||||||
|
state.sampling_step = self.step
|
||||||
|
shared.total_tqdm.update()
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def initialize(self, p):
|
def initialize(self, p):
|
||||||
|
@ -192,7 +198,7 @@ class VanillaStableDiffusionSampler:
|
||||||
self.init_latent = x
|
self.init_latent = x
|
||||||
self.step = 0
|
self.step = 0
|
||||||
|
|
||||||
samples = self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)
|
samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
@ -206,9 +212,9 @@ class VanillaStableDiffusionSampler:
|
||||||
|
|
||||||
# existing code fails with certain step counts, like 9
|
# existing code fails with certain step counts, like 9
|
||||||
try:
|
try:
|
||||||
samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
|
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||||
except Exception:
|
except Exception:
|
||||||
samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
|
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||||
|
|
||||||
return samples_ddim
|
return samples_ddim
|
||||||
|
|
||||||
|
@ -223,6 +229,9 @@ class CFGDenoiser(torch.nn.Module):
|
||||||
self.step = 0
|
self.step = 0
|
||||||
|
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||||
|
if state.interrupted or state.skipped:
|
||||||
|
raise InterruptedException
|
||||||
|
|
||||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||||
|
|
||||||
|
@ -268,25 +277,6 @@ class CFGDenoiser(torch.nn.Module):
|
||||||
return denoised
|
return denoised
|
||||||
|
|
||||||
|
|
||||||
def extended_trange(sampler, count, *args, **kwargs):
|
|
||||||
state.sampling_steps = count
|
|
||||||
state.sampling_step = 0
|
|
||||||
|
|
||||||
seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
|
|
||||||
|
|
||||||
for x in seq:
|
|
||||||
if state.interrupted or state.skipped:
|
|
||||||
break
|
|
||||||
|
|
||||||
if sampler.stop_at is not None and x > sampler.stop_at:
|
|
||||||
break
|
|
||||||
|
|
||||||
yield x
|
|
||||||
|
|
||||||
state.sampling_step += 1
|
|
||||||
shared.total_tqdm.update()
|
|
||||||
|
|
||||||
|
|
||||||
class TorchHijack:
|
class TorchHijack:
|
||||||
def __init__(self, kdiff_sampler):
|
def __init__(self, kdiff_sampler):
|
||||||
self.kdiff_sampler = kdiff_sampler
|
self.kdiff_sampler = kdiff_sampler
|
||||||
|
@ -314,9 +304,28 @@ class KDiffusionSampler:
|
||||||
self.eta = None
|
self.eta = None
|
||||||
self.default_eta = 1.0
|
self.default_eta = 1.0
|
||||||
self.config = None
|
self.config = None
|
||||||
|
self.last_latent = None
|
||||||
|
|
||||||
def callback_state(self, d):
|
def callback_state(self, d):
|
||||||
store_latent(d["denoised"])
|
step = d['i']
|
||||||
|
latent = d["denoised"]
|
||||||
|
store_latent(latent)
|
||||||
|
self.last_latent = latent
|
||||||
|
|
||||||
|
if self.stop_at is not None and step > self.stop_at:
|
||||||
|
raise InterruptedException
|
||||||
|
|
||||||
|
state.sampling_step = step
|
||||||
|
shared.total_tqdm.update()
|
||||||
|
|
||||||
|
def launch_sampling(self, steps, func):
|
||||||
|
state.sampling_steps = steps
|
||||||
|
state.sampling_step = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
return func()
|
||||||
|
except InterruptedException:
|
||||||
|
return self.last_latent
|
||||||
|
|
||||||
def number_of_needed_noises(self, p):
|
def number_of_needed_noises(self, p):
|
||||||
return p.steps
|
return p.steps
|
||||||
|
@ -339,9 +348,6 @@ class KDiffusionSampler:
|
||||||
self.sampler_noise_index = 0
|
self.sampler_noise_index = 0
|
||||||
self.eta = p.eta or opts.eta_ancestral
|
self.eta = p.eta or opts.eta_ancestral
|
||||||
|
|
||||||
if hasattr(k_diffusion.sampling, 'trange'):
|
|
||||||
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs)
|
|
||||||
|
|
||||||
if self.sampler_noises is not None:
|
if self.sampler_noises is not None:
|
||||||
k_diffusion.sampling.torch = TorchHijack(self)
|
k_diffusion.sampling.torch = TorchHijack(self)
|
||||||
|
|
||||||
|
@ -383,8 +389,9 @@ class KDiffusionSampler:
|
||||||
|
|
||||||
self.model_wrap_cfg.init_latent = x
|
self.model_wrap_cfg.init_latent = x
|
||||||
|
|
||||||
return self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
|
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
|
||||||
|
return samples
|
||||||
|
|
||||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
|
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
|
||||||
steps = steps or p.steps
|
steps = steps or p.steps
|
||||||
|
@ -406,6 +413,8 @@ class KDiffusionSampler:
|
||||||
extra_params_kwargs['n'] = steps
|
extra_params_kwargs['n'] = steps
|
||||||
else:
|
else:
|
||||||
extra_params_kwargs['sigmas'] = sigmas
|
extra_params_kwargs['sigmas'] = sigmas
|
||||||
samples = self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
|
|
||||||
|
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ import modules.memmon
|
||||||
import modules.sd_models
|
import modules.sd_models
|
||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.devices as devices
|
import modules.devices as devices
|
||||||
from modules import sd_samplers, sd_models
|
from modules import sd_samplers, sd_models, localization
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.paths import models_path, script_path, sd_path
|
from modules.paths import models_path, script_path, sd_path
|
||||||
|
|
||||||
|
@ -31,6 +31,7 @@ parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not
|
||||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||||
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
||||||
|
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||||
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
||||||
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
||||||
|
@ -40,6 +41,7 @@ parser.add_argument("--unload-gfpgan", action='store_true', help="does not do an
|
||||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
|
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
|
||||||
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
|
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
|
||||||
|
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
|
||||||
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
|
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
|
||||||
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
|
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
|
||||||
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
|
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
|
||||||
|
@ -68,14 +70,26 @@ parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image upload
|
||||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
|
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
|
||||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||||
|
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
||||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||||
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
||||||
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
||||||
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
|
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
|
||||||
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
||||||
|
parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
|
||||||
|
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
|
||||||
|
|
||||||
cmd_opts = parser.parse_args()
|
cmd_opts = parser.parse_args()
|
||||||
|
restricted_opts = [
|
||||||
|
"samples_filename_pattern",
|
||||||
|
"outdir_samples",
|
||||||
|
"outdir_txt2img_samples",
|
||||||
|
"outdir_img2img_samples",
|
||||||
|
"outdir_extras_samples",
|
||||||
|
"outdir_grids",
|
||||||
|
"outdir_txt2img_grids",
|
||||||
|
"outdir_save",
|
||||||
|
]
|
||||||
|
|
||||||
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
|
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
|
||||||
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'])
|
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'])
|
||||||
|
@ -92,7 +106,6 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
||||||
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
||||||
loaded_hypernetwork = None
|
loaded_hypernetwork = None
|
||||||
|
|
||||||
|
|
||||||
def reload_hypernetworks():
|
def reload_hypernetworks():
|
||||||
global hypernetworks
|
global hypernetworks
|
||||||
|
|
||||||
|
@ -140,6 +153,8 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
|
||||||
|
|
||||||
face_restorers = []
|
face_restorers = []
|
||||||
|
|
||||||
|
localization.list_localizations(cmd_opts.localizations_dir)
|
||||||
|
|
||||||
|
|
||||||
def realesrgan_models_names():
|
def realesrgan_models_names():
|
||||||
import modules.realesrgan_model
|
import modules.realesrgan_model
|
||||||
|
@ -280,11 +295,13 @@ options_templates.update(options_section(('ui', "User interface"), {
|
||||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||||
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
|
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
|
||||||
|
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
|
||||||
"font": OptionInfo("", "Font for image grids that have text"),
|
"font": OptionInfo("", "Font for image grids that have text"),
|
||||||
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
||||||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||||
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||||
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
|
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
|
||||||
|
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||||
|
|
|
@ -45,7 +45,7 @@ class StyleDatabase:
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
return
|
return
|
||||||
|
|
||||||
with open(path, "r", encoding="utf8", newline='') as file:
|
with open(path, "r", encoding="utf-8-sig", newline='') as file:
|
||||||
reader = csv.DictReader(file)
|
reader = csv.DictReader(file)
|
||||||
for row in reader:
|
for row in reader:
|
||||||
# Support loading old CSV format with "name, text"-columns
|
# Support loading old CSV format with "name, text"-columns
|
||||||
|
@ -79,7 +79,7 @@ class StyleDatabase:
|
||||||
def save_styles(self, path: str) -> None:
|
def save_styles(self, path: str) -> None:
|
||||||
# Write to temporary file first, so we don't nuke the file if something goes wrong
|
# Write to temporary file first, so we don't nuke the file if something goes wrong
|
||||||
fd, temp_path = tempfile.mkstemp(".csv")
|
fd, temp_path = tempfile.mkstemp(".csv")
|
||||||
with os.fdopen(fd, "w", encoding="utf8", newline='') as file:
|
with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
|
||||||
# _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
|
# _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
|
||||||
# and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
|
# and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
|
||||||
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
|
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
|
||||||
|
|
|
@ -137,6 +137,7 @@ class EmbeddingDatabase:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
|
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
|
||||||
|
print("Embeddings:", ', '.join(self.word_embeddings.keys()))
|
||||||
|
|
||||||
def find_embedding_at_position(self, tokens, offset):
|
def find_embedding_at_position(self, tokens, offset):
|
||||||
token = tokens[offset]
|
token = tokens[offset]
|
||||||
|
@ -296,6 +297,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
||||||
sd_model=shared.sd_model,
|
sd_model=shared.sd_model,
|
||||||
do_not_save_grid=True,
|
do_not_save_grid=True,
|
||||||
do_not_save_samples=True,
|
do_not_save_samples=True,
|
||||||
|
do_not_reload_embeddings=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if preview_from_txt2img:
|
if preview_from_txt2img:
|
||||||
|
|
145
modules/ui.py
145
modules/ui.py
|
@ -23,9 +23,9 @@ import gradio as gr
|
||||||
import gradio.utils
|
import gradio.utils
|
||||||
import gradio.routes
|
import gradio.routes
|
||||||
|
|
||||||
from modules import sd_hijack, sd_models
|
from modules import sd_hijack, sd_models, localization
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
from modules.shared import opts, cmd_opts
|
from modules.shared import opts, cmd_opts, restricted_opts
|
||||||
if cmd_opts.deepdanbooru:
|
if cmd_opts.deepdanbooru:
|
||||||
from modules.deepbooru import get_deepbooru_tags
|
from modules.deepbooru import get_deepbooru_tags
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
@ -56,7 +56,7 @@ if not cmd_opts.share and not cmd_opts.listen:
|
||||||
if cmd_opts.ngrok != None:
|
if cmd_opts.ngrok != None:
|
||||||
import modules.ngrok as ngrok
|
import modules.ngrok as ngrok
|
||||||
print('ngrok authtoken detected, trying to connect...')
|
print('ngrok authtoken detected, trying to connect...')
|
||||||
ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860)
|
ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860, cmd_opts.ngrok_region)
|
||||||
|
|
||||||
|
|
||||||
def gr_show(visible=True):
|
def gr_show(visible=True):
|
||||||
|
@ -261,6 +261,19 @@ def wrap_gradio_call(func, extra_outputs=None):
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
def calc_time_left(progress, threshold, label, force_display):
|
||||||
|
if progress == 0:
|
||||||
|
return ""
|
||||||
|
else:
|
||||||
|
time_since_start = time.time() - shared.state.time_start
|
||||||
|
eta = (time_since_start/progress)
|
||||||
|
eta_relative = eta-time_since_start
|
||||||
|
if (eta_relative > threshold and progress > 0.02) or force_display:
|
||||||
|
return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
|
||||||
|
else:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def check_progress_call(id_part):
|
def check_progress_call(id_part):
|
||||||
if shared.state.job_count == 0:
|
if shared.state.job_count == 0:
|
||||||
return "", gr_show(False), gr_show(False), gr_show(False)
|
return "", gr_show(False), gr_show(False), gr_show(False)
|
||||||
|
@ -272,11 +285,15 @@ def check_progress_call(id_part):
|
||||||
if shared.state.sampling_steps > 0:
|
if shared.state.sampling_steps > 0:
|
||||||
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
|
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
|
||||||
|
|
||||||
|
time_left = calc_time_left( progress, 60, " ETA:", shared.state.time_left_force_display )
|
||||||
|
if time_left != "":
|
||||||
|
shared.state.time_left_force_display = True
|
||||||
|
|
||||||
progress = min(progress, 1)
|
progress = min(progress, 1)
|
||||||
|
|
||||||
progressbar = ""
|
progressbar = ""
|
||||||
if opts.show_progressbar:
|
if opts.show_progressbar:
|
||||||
progressbar = f"""<div class='progressDiv'><div class='progress' style="width:{progress * 100}%">{str(int(progress*100))+"%" if progress > 0.01 else ""}</div></div>"""
|
progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:hidden;width:{progress * 100}%">{str(int(progress*100))+"%"+time_left if progress > 0.01 else ""}</div></div>"""
|
||||||
|
|
||||||
image = gr_show(False)
|
image = gr_show(False)
|
||||||
preview_visibility = gr_show(False)
|
preview_visibility = gr_show(False)
|
||||||
|
@ -308,6 +325,8 @@ def check_progress_call_initial(id_part):
|
||||||
shared.state.current_latent = None
|
shared.state.current_latent = None
|
||||||
shared.state.current_image = None
|
shared.state.current_image = None
|
||||||
shared.state.textinfo = None
|
shared.state.textinfo = None
|
||||||
|
shared.state.time_start = time.time()
|
||||||
|
shared.state.time_left_force_display = False
|
||||||
|
|
||||||
return check_progress_call(id_part)
|
return check_progress_call(id_part)
|
||||||
|
|
||||||
|
@ -508,9 +527,11 @@ def create_toprow(is_img2img):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=1, elem_id="style_pos_col"):
|
with gr.Column(scale=1, elem_id="style_pos_col"):
|
||||||
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
|
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
|
||||||
|
prompt_style.save_to_config = True
|
||||||
|
|
||||||
with gr.Column(scale=1, elem_id="style_neg_col"):
|
with gr.Column(scale=1, elem_id="style_neg_col"):
|
||||||
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
|
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
|
||||||
|
prompt_style2.save_to_config = True
|
||||||
|
|
||||||
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
|
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
|
||||||
|
|
||||||
|
@ -540,6 +561,10 @@ def apply_setting(key, value):
|
||||||
if value is None:
|
if value is None:
|
||||||
return gr.update()
|
return gr.update()
|
||||||
|
|
||||||
|
# dont allow model to be swapped when model hash exists in prompt
|
||||||
|
if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
|
||||||
|
return gr.update()
|
||||||
|
|
||||||
if key == "sd_model_checkpoint":
|
if key == "sd_model_checkpoint":
|
||||||
ckpt_info = sd_models.get_closet_checkpoint_match(value)
|
ckpt_info = sd_models.get_closet_checkpoint_match(value)
|
||||||
|
|
||||||
|
@ -566,6 +591,24 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
import modules.img2img
|
import modules.img2img
|
||||||
import modules.txt2img
|
import modules.txt2img
|
||||||
|
|
||||||
|
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
|
||||||
|
def refresh():
|
||||||
|
refresh_method()
|
||||||
|
args = refreshed_args() if callable(refreshed_args) else refreshed_args
|
||||||
|
|
||||||
|
for k, v in args.items():
|
||||||
|
setattr(refresh_component, k, v)
|
||||||
|
|
||||||
|
return gr.update(**(args or {}))
|
||||||
|
|
||||||
|
refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
|
||||||
|
refresh_button.click(
|
||||||
|
fn = refresh,
|
||||||
|
inputs = [],
|
||||||
|
outputs = [refresh_component]
|
||||||
|
)
|
||||||
|
return refresh_button
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||||
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
|
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
|
||||||
dummy_component = gr.Label(visible=False)
|
dummy_component = gr.Label(visible=False)
|
||||||
|
@ -1016,6 +1059,15 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
with gr.TabItem('Batch Process'):
|
with gr.TabItem('Batch Process'):
|
||||||
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
|
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
|
||||||
|
|
||||||
|
with gr.TabItem('Batch from Directory'):
|
||||||
|
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs,
|
||||||
|
placeholder="A directory on the same machine where the server is running."
|
||||||
|
)
|
||||||
|
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs,
|
||||||
|
placeholder="Leave blank to save images to the default path."
|
||||||
|
)
|
||||||
|
show_extras_results = gr.Checkbox(label='Show result images', value=True)
|
||||||
|
|
||||||
with gr.Tabs(elem_id="extras_resize_mode"):
|
with gr.Tabs(elem_id="extras_resize_mode"):
|
||||||
with gr.TabItem('Scale by'):
|
with gr.TabItem('Scale by'):
|
||||||
upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
|
upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
|
||||||
|
@ -1027,10 +1079,10 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
|
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
|
||||||
|
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
extras_upscaler_1 = gr.Radio(label='Upscaler 1', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
|
extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
|
||||||
|
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
extras_upscaler_2 = gr.Radio(label='Upscaler 2', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
|
extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
|
||||||
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1)
|
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1)
|
||||||
|
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
|
@ -1060,6 +1112,9 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
dummy_component,
|
dummy_component,
|
||||||
extras_image,
|
extras_image,
|
||||||
image_batch,
|
image_batch,
|
||||||
|
extras_batch_input_dir,
|
||||||
|
extras_batch_output_dir,
|
||||||
|
show_extras_results,
|
||||||
gfpgan_visibility,
|
gfpgan_visibility,
|
||||||
codeformer_visibility,
|
codeformer_visibility,
|
||||||
codeformer_weight,
|
codeformer_weight,
|
||||||
|
@ -1191,8 +1246,12 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
|
|
||||||
with gr.Tab(label="Train"):
|
with gr.Tab(label="Train"):
|
||||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
|
||||||
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
with gr.Row():
|
||||||
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
|
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||||
|
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
|
||||||
|
with gr.Row():
|
||||||
|
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
|
||||||
|
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
|
||||||
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
|
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
|
||||||
batch_size = gr.Number(label='Batch size', value=1, precision=0)
|
batch_size = gr.Number(label='Batch size', value=1, precision=0)
|
||||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||||
|
@ -1301,6 +1360,8 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
batch_size,
|
batch_size,
|
||||||
dataset_directory,
|
dataset_directory,
|
||||||
log_directory,
|
log_directory,
|
||||||
|
training_width,
|
||||||
|
training_height,
|
||||||
steps,
|
steps,
|
||||||
create_image_every,
|
create_image_every,
|
||||||
save_embedding_every,
|
save_embedding_every,
|
||||||
|
@ -1340,31 +1401,18 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
else:
|
else:
|
||||||
raise Exception(f'bad options item type: {str(t)} for key {key}')
|
raise Exception(f'bad options item type: {str(t)} for key {key}')
|
||||||
|
|
||||||
|
elem_id = "setting_"+key
|
||||||
|
|
||||||
if info.refresh is not None:
|
if info.refresh is not None:
|
||||||
if is_quicksettings:
|
if is_quicksettings:
|
||||||
res = comp(label=info.label, value=fun, **(args or {}))
|
res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
|
||||||
refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_"+key)
|
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
|
||||||
else:
|
else:
|
||||||
with gr.Row(variant="compact"):
|
with gr.Row(variant="compact"):
|
||||||
res = comp(label=info.label, value=fun, **(args or {}))
|
res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
|
||||||
refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_" + key)
|
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
|
||||||
|
|
||||||
def refresh():
|
|
||||||
info.refresh()
|
|
||||||
refreshed_args = info.component_args() if callable(info.component_args) else info.component_args
|
|
||||||
|
|
||||||
for k, v in refreshed_args.items():
|
|
||||||
setattr(res, k, v)
|
|
||||||
|
|
||||||
return gr.update(**(refreshed_args or {}))
|
|
||||||
|
|
||||||
refresh_button.click(
|
|
||||||
fn=refresh,
|
|
||||||
inputs=[],
|
|
||||||
outputs=[res],
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
res = comp(label=info.label, value=fun, **(args or {}))
|
res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
|
||||||
|
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
@ -1373,7 +1421,10 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
component_dict = {}
|
component_dict = {}
|
||||||
|
|
||||||
def open_folder(f):
|
def open_folder(f):
|
||||||
if not os.path.isdir(f):
|
if not os.path.exists(f):
|
||||||
|
print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
|
||||||
|
return
|
||||||
|
elif not os.path.isdir(f):
|
||||||
print(f"""
|
print(f"""
|
||||||
WARNING
|
WARNING
|
||||||
An open_folder request was made with an argument that is not a folder.
|
An open_folder request was made with an argument that is not a folder.
|
||||||
|
@ -1406,6 +1457,9 @@ Requested path was: {f}
|
||||||
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
|
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
|
||||||
|
continue
|
||||||
|
|
||||||
oldval = opts.data.get(key, None)
|
oldval = opts.data.get(key, None)
|
||||||
opts.data[key] = value
|
opts.data[key] = value
|
||||||
|
|
||||||
|
@ -1423,6 +1477,9 @@ Requested path was: {f}
|
||||||
if not opts.same_type(value, opts.data_labels[key].default):
|
if not opts.same_type(value, opts.data_labels[key].default):
|
||||||
return gr.update(visible=True), opts.dumpjson()
|
return gr.update(visible=True), opts.dumpjson()
|
||||||
|
|
||||||
|
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
|
||||||
|
return gr.update(value=oldval), opts.dumpjson()
|
||||||
|
|
||||||
oldval = opts.data.get(key, None)
|
oldval = opts.data.get(key, None)
|
||||||
opts.data[key] = value
|
opts.data[key] = value
|
||||||
|
|
||||||
|
@ -1479,6 +1536,9 @@ Requested path was: {f}
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
|
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
|
||||||
|
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
|
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
|
||||||
restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
|
restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
|
||||||
|
|
||||||
|
@ -1489,6 +1549,13 @@ Requested path was: {f}
|
||||||
_js='function(){}'
|
_js='function(){}'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
download_localization.click(
|
||||||
|
fn=lambda: None,
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
_js='download_localization'
|
||||||
|
)
|
||||||
|
|
||||||
def reload_scripts():
|
def reload_scripts():
|
||||||
modules.scripts.reload_script_body_only()
|
modules.scripts.reload_script_body_only()
|
||||||
reload_javascript() # need to refresh the html page
|
reload_javascript() # need to refresh the html page
|
||||||
|
@ -1692,7 +1759,7 @@ Requested path was: {f}
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
def loadsave(path, x):
|
def loadsave(path, x):
|
||||||
def apply_field(obj, field, condition=None):
|
def apply_field(obj, field, condition=None, init_field=None):
|
||||||
key = path + "/" + field
|
key = path + "/" + field
|
||||||
|
|
||||||
if getattr(obj,'custom_script_source',None) is not None:
|
if getattr(obj,'custom_script_source',None) is not None:
|
||||||
|
@ -1704,8 +1771,12 @@ Requested path was: {f}
|
||||||
saved_value = ui_settings.get(key, None)
|
saved_value = ui_settings.get(key, None)
|
||||||
if saved_value is None:
|
if saved_value is None:
|
||||||
ui_settings[key] = getattr(obj, field)
|
ui_settings[key] = getattr(obj, field)
|
||||||
elif condition is None or condition(saved_value):
|
elif condition and not condition(saved_value):
|
||||||
|
print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
|
||||||
|
else:
|
||||||
setattr(obj, field, saved_value)
|
setattr(obj, field, saved_value)
|
||||||
|
if init_field is not None:
|
||||||
|
init_field(saved_value)
|
||||||
|
|
||||||
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible:
|
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible:
|
||||||
apply_field(x, 'visible')
|
apply_field(x, 'visible')
|
||||||
|
@ -1728,9 +1799,16 @@ Requested path was: {f}
|
||||||
if type(x) == gr.Number:
|
if type(x) == gr.Number:
|
||||||
apply_field(x, 'value')
|
apply_field(x, 'value')
|
||||||
|
|
||||||
|
# Since there are many dropdowns that shouldn't be saved,
|
||||||
|
# we only mark dropdowns that should be saved.
|
||||||
|
if type(x) == gr.Dropdown and getattr(x, 'save_to_config', False):
|
||||||
|
apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None))
|
||||||
|
apply_field(x, 'visible')
|
||||||
|
|
||||||
visit(txt2img_interface, loadsave, "txt2img")
|
visit(txt2img_interface, loadsave, "txt2img")
|
||||||
visit(img2img_interface, loadsave, "img2img")
|
visit(img2img_interface, loadsave, "img2img")
|
||||||
visit(extras_interface, loadsave, "extras")
|
visit(extras_interface, loadsave, "extras")
|
||||||
|
visit(modelmerger_interface, loadsave, "modelmerger")
|
||||||
|
|
||||||
if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
|
if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
|
||||||
with open(ui_config_file, "w", encoding="utf8") as file:
|
with open(ui_config_file, "w", encoding="utf8") as file:
|
||||||
|
@ -1748,6 +1826,11 @@ def load_javascript(raw_response):
|
||||||
with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
|
with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
|
||||||
javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"
|
javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"
|
||||||
|
|
||||||
|
if cmd_opts.theme is not None:
|
||||||
|
javascript += f"\n<script>set_theme('{cmd_opts.theme}');</script>\n"
|
||||||
|
|
||||||
|
javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"
|
||||||
|
|
||||||
def template_response(*args, **kwargs):
|
def template_response(*args, **kwargs):
|
||||||
res = raw_response(*args, **kwargs)
|
res = raw_response(*args, **kwargs)
|
||||||
res.body = res.body.replace(
|
res.body = res.body.replace(
|
||||||
|
|
|
@ -23,3 +23,4 @@ resize-right
|
||||||
torchdiffeq
|
torchdiffeq
|
||||||
kornia
|
kornia
|
||||||
lark
|
lark
|
||||||
|
inflection
|
||||||
|
|
|
@ -22,3 +22,4 @@ resize-right==0.0.2
|
||||||
torchdiffeq==0.2.3
|
torchdiffeq==0.2.3
|
||||||
kornia==0.6.7
|
kornia==0.6.7
|
||||||
lark==1.1.2
|
lark==1.1.2
|
||||||
|
inflection==0.5.1
|
||||||
|
|
10
script.js
10
script.js
|
@ -21,20 +21,20 @@ function onUiTabChange(callback){
|
||||||
uiTabChangeCallbacks.push(callback)
|
uiTabChangeCallbacks.push(callback)
|
||||||
}
|
}
|
||||||
|
|
||||||
function runCallback(x){
|
function runCallback(x, m){
|
||||||
try {
|
try {
|
||||||
x()
|
x(m)
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
(console.error || console.log).call(console, e.message, e);
|
(console.error || console.log).call(console, e.message, e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
function executeCallbacks(queue) {
|
function executeCallbacks(queue, m) {
|
||||||
queue.forEach(runCallback)
|
queue.forEach(function(x){runCallback(x, m)})
|
||||||
}
|
}
|
||||||
|
|
||||||
document.addEventListener("DOMContentLoaded", function() {
|
document.addEventListener("DOMContentLoaded", function() {
|
||||||
var mutationObserver = new MutationObserver(function(m){
|
var mutationObserver = new MutationObserver(function(m){
|
||||||
executeCallbacks(uiUpdateCallbacks);
|
executeCallbacks(uiUpdateCallbacks, m);
|
||||||
const newTab = get_uiCurrentTab();
|
const newTab = get_uiCurrentTab();
|
||||||
if ( newTab && ( newTab !== uiCurrentTab ) ) {
|
if ( newTab && ( newTab !== uiCurrentTab ) ) {
|
||||||
uiCurrentTab = newTab;
|
uiCurrentTab = newTab;
|
||||||
|
|
|
@ -233,6 +233,21 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_
|
||||||
return processed_result
|
return processed_result
|
||||||
|
|
||||||
|
|
||||||
|
class SharedSettingsStackHelper(object):
|
||||||
|
def __enter__(self):
|
||||||
|
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
|
||||||
|
self.hypernetwork = opts.sd_hypernetwork
|
||||||
|
self.model = shared.sd_model
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, tb):
|
||||||
|
modules.sd_models.reload_model_weights(self.model)
|
||||||
|
|
||||||
|
hypernetwork.load_hypernetwork(self.hypernetwork)
|
||||||
|
hypernetwork.apply_strength()
|
||||||
|
|
||||||
|
opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
|
||||||
|
|
||||||
|
|
||||||
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
|
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
|
||||||
re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")
|
re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")
|
||||||
|
|
||||||
|
@ -267,9 +282,6 @@ class Script(scripts.Script):
|
||||||
if not opts.return_grid:
|
if not opts.return_grid:
|
||||||
p.batch_size = 1
|
p.batch_size = 1
|
||||||
|
|
||||||
|
|
||||||
CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
|
|
||||||
|
|
||||||
def process_axis(opt, vals):
|
def process_axis(opt, vals):
|
||||||
if opt.label == 'Nothing':
|
if opt.label == 'Nothing':
|
||||||
return [0]
|
return [0]
|
||||||
|
@ -367,6 +379,7 @@ class Script(scripts.Script):
|
||||||
|
|
||||||
return process_images(pc)
|
return process_images(pc)
|
||||||
|
|
||||||
|
with SharedSettingsStackHelper():
|
||||||
processed = draw_xy_grid(
|
processed = draw_xy_grid(
|
||||||
p,
|
p,
|
||||||
xs=xs,
|
xs=xs,
|
||||||
|
@ -381,13 +394,4 @@ class Script(scripts.Script):
|
||||||
if opts.grid_save:
|
if opts.grid_save:
|
||||||
images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p)
|
images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p)
|
||||||
|
|
||||||
# restore checkpoint in case it was changed by axes
|
|
||||||
modules.sd_models.reload_model_weights(shared.sd_model)
|
|
||||||
|
|
||||||
hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
|
|
||||||
hypernetwork.apply_strength()
|
|
||||||
|
|
||||||
|
|
||||||
opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers
|
|
||||||
|
|
||||||
return processed
|
return processed
|
||||||
|
|
|
@ -478,7 +478,7 @@ input[type="range"]{
|
||||||
padding: 0;
|
padding: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork{
|
#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
|
||||||
max-width: 2.5em;
|
max-width: 2.5em;
|
||||||
min-width: 2.5em;
|
min-width: 2.5em;
|
||||||
height: 2.4em;
|
height: 2.4em;
|
||||||
|
|
|
@ -33,7 +33,7 @@ goto :launch
|
||||||
:skip_venv
|
:skip_venv
|
||||||
|
|
||||||
:launch
|
:launch
|
||||||
%PYTHON% launch.py
|
%PYTHON% launch.py %*
|
||||||
pause
|
pause
|
||||||
exit /b
|
exit /b
|
||||||
|
|
||||||
|
|
54
webui.py
54
webui.py
|
@ -4,7 +4,7 @@ import time
|
||||||
import importlib
|
import importlib
|
||||||
import signal
|
import signal
|
||||||
import threading
|
import threading
|
||||||
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.gzip import GZipMiddleware
|
from fastapi.middleware.gzip import GZipMiddleware
|
||||||
|
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
|
@ -31,7 +31,6 @@ from modules.paths import script_path
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
import modules.hypernetworks.hypernetwork
|
import modules.hypernetworks.hypernetwork
|
||||||
|
|
||||||
|
|
||||||
queue_lock = threading.Lock()
|
queue_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
@ -87,10 +86,6 @@ def initialize():
|
||||||
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
|
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
|
||||||
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
|
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
|
||||||
|
|
||||||
|
|
||||||
def webui():
|
|
||||||
initialize()
|
|
||||||
|
|
||||||
# make the program just exit at ctrl+c without waiting for anything
|
# make the program just exit at ctrl+c without waiting for anything
|
||||||
def sigint_handler(sig, frame):
|
def sigint_handler(sig, frame):
|
||||||
print(f'Interrupted with signal {sig} in {frame}')
|
print(f'Interrupted with signal {sig} in {frame}')
|
||||||
|
@ -98,8 +93,35 @@ def webui():
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, sigint_handler)
|
signal.signal(signal.SIGINT, sigint_handler)
|
||||||
|
|
||||||
while 1:
|
|
||||||
|
|
||||||
|
def create_api(app):
|
||||||
|
from modules.api.api import Api
|
||||||
|
api = Api(app, queue_lock)
|
||||||
|
return api
|
||||||
|
|
||||||
|
def wait_on_server(demo=None):
|
||||||
|
while 1:
|
||||||
|
time.sleep(0.5)
|
||||||
|
if demo and getattr(demo, 'do_restart', False):
|
||||||
|
time.sleep(0.5)
|
||||||
|
demo.close()
|
||||||
|
time.sleep(0.5)
|
||||||
|
break
|
||||||
|
|
||||||
|
def api_only():
|
||||||
|
initialize()
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||||
|
api = create_api(app)
|
||||||
|
|
||||||
|
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
|
||||||
|
|
||||||
|
|
||||||
|
def webui(launch_api=False):
|
||||||
|
initialize()
|
||||||
|
|
||||||
|
while 1:
|
||||||
demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
|
demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
|
||||||
|
|
||||||
app, local_url, share_url = demo.launch(
|
app, local_url, share_url = demo.launch(
|
||||||
|
@ -114,13 +136,10 @@ def webui():
|
||||||
|
|
||||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||||
|
|
||||||
while 1:
|
if (launch_api):
|
||||||
time.sleep(0.5)
|
create_api(app)
|
||||||
if getattr(demo, 'do_restart', False):
|
|
||||||
time.sleep(0.5)
|
wait_on_server(demo)
|
||||||
demo.close()
|
|
||||||
time.sleep(0.5)
|
|
||||||
break
|
|
||||||
|
|
||||||
sd_samplers.set_samplers()
|
sd_samplers.set_samplers()
|
||||||
|
|
||||||
|
@ -133,5 +152,10 @@ def webui():
|
||||||
print('Restarting Gradio')
|
print('Restarting Gradio')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
task = []
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
webui()
|
if cmd_opts.nowebui:
|
||||||
|
api_only()
|
||||||
|
else:
|
||||||
|
webui(cmd_opts.api)
|
2
webui.sh
2
webui.sh
|
@ -138,4 +138,4 @@ fi
|
||||||
printf "\n%s\n" "${delimiter}"
|
printf "\n%s\n" "${delimiter}"
|
||||||
printf "Launching launch.py..."
|
printf "Launching launch.py..."
|
||||||
printf "\n%s\n" "${delimiter}"
|
printf "\n%s\n" "${delimiter}"
|
||||||
"${python_cmd}" "${LAUNCH_SCRIPT}"
|
"${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
|
||||||
|
|
Loading…
Reference in a new issue