Merge branch 'AUTOMATIC1111:master' into interrogate

This commit is contained in:
Vladimir Mandic 2023-01-23 12:25:07 -05:00 committed by GitHub
commit efa7287be0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 140 additions and 27 deletions

View file

@ -57,6 +57,7 @@ class LoraUpDownModule:
def __init__(self): def __init__(self):
self.up = None self.up = None
self.down = None self.down = None
self.alpha = None
def assign_lora_names_to_compvis_modules(sd_model): def assign_lora_names_to_compvis_modules(sd_model):
@ -92,6 +93,15 @@ def load_lora(name, filename):
keys_failed_to_match.append(key_diffusers) keys_failed_to_match.append(key_diffusers)
continue continue
lora_module = lora.modules.get(key, None)
if lora_module is None:
lora_module = LoraUpDownModule()
lora.modules[key] = lora_module
if lora_key == "alpha":
lora_module.alpha = weight.item()
continue
if type(sd_module) == torch.nn.Linear: if type(sd_module) == torch.nn.Linear:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.Conv2d: elif type(sd_module) == torch.nn.Conv2d:
@ -104,17 +114,12 @@ def load_lora(name, filename):
module.to(device=devices.device, dtype=devices.dtype) module.to(device=devices.device, dtype=devices.dtype)
lora_module = lora.modules.get(key, None)
if lora_module is None:
lora_module = LoraUpDownModule()
lora.modules[key] = lora_module
if lora_key == "lora_up.weight": if lora_key == "lora_up.weight":
lora_module.up = module lora_module.up = module
elif lora_key == "lora_down.weight": elif lora_key == "lora_down.weight":
lora_module.down = module lora_module.down = module
else: else:
assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight or lora_down.weight' assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
if len(keys_failed_to_match) > 0: if len(keys_failed_to_match) > 0:
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}") print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
@ -161,7 +166,7 @@ def lora_forward(module, input, res):
for lora in loaded_loras: for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None) module = lora.modules.get(lora_layer_name, None)
if module is not None: if module is not None:
res = res + module.up(module.down(input)) * lora.multiplier res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
return res return res

7
html/image-update.svg Normal file
View file

@ -0,0 +1,7 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
<filter id='shadow' color-interpolation-filters="sRGB">
<feDropShadow flood-color="black" dx="0" dy="0" flood-opacity="0.9" stdDeviation="0.5"/>
<feDropShadow flood-color="black" dx="0" dy="0" flood-opacity="0.9" stdDeviation="0.5"/>
</filter>
<path style="filter:url(#shadow);" fill="#FFFFFF" d="M13.18 19C13.35 19.72 13.64 20.39 14.03 21H5C3.9 21 3 20.11 3 19V5C3 3.9 3.9 3 5 3H19C20.11 3 21 3.9 21 5V11.18C20.5 11.07 20 11 19.5 11C19.33 11 19.17 11 19 11.03V5H5V19H13.18M11.21 15.83L9.25 13.47L6.5 17H13.03C13.14 15.54 13.73 14.22 14.64 13.19L13.96 12.29L11.21 15.83M19 13.5V12L16.75 14.25L19 16.5V15C20.38 15 21.5 16.12 21.5 17.5C21.5 17.9 21.41 18.28 21.24 18.62L22.33 19.71C22.75 19.08 23 18.32 23 17.5C23 15.29 21.21 13.5 19 13.5M19 20C17.62 20 16.5 18.88 16.5 17.5C16.5 17.1 16.59 16.72 16.76 16.38L15.67 15.29C15.25 15.92 15 16.68 15 17.5C15 19.71 16.79 21.5 19 21.5V23L21.25 20.75L19 18.5V20Z" />
</svg>

After

Width:  |  Height:  |  Size: 989 B

View file

@ -179,7 +179,7 @@ def run_extensions_installers(settings_file):
def prepare_environment(): def prepare_environment():
global skip_install global skip_install
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "") commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
@ -187,8 +187,6 @@ def prepare_environment():
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1") clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl')
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git") taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
@ -210,6 +208,7 @@ def prepare_environment():
sys.argv, _ = extract_arg(sys.argv, '-f') sys.argv, _ = extract_arg(sys.argv, '-f')
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test') sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers') sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
sys.argv, update_check = extract_arg(sys.argv, '--update-check') sys.argv, update_check = extract_arg(sys.argv, '--update-check')
sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests') sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
sys.argv, skip_install = extract_arg(sys.argv, '--skip-install') sys.argv, skip_install = extract_arg(sys.argv, '--skip-install')
@ -221,7 +220,7 @@ def prepare_environment():
print(f"Python {sys.version}") print(f"Python {sys.version}")
print(f"Commit hash: {commit}") print(f"Commit hash: {commit}")
if not is_installed("torch") or not is_installed("torchvision"): if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch") run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")
if not skip_torch_cuda_test: if not skip_torch_cuda_test:
@ -239,7 +238,7 @@ def prepare_environment():
if (not is_installed("xformers") or reinstall_xformers) and xformers: if (not is_installed("xformers") or reinstall_xformers) and xformers:
if platform.system() == "Windows": if platform.system() == "Windows":
if platform.python_version().startswith("3.10"): if platform.python_version().startswith("3.10"):
run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers") run_pip(f"install -U -I --no-deps xformers==0.0.16rc425", "xformers")
else: else:
print("Installation of xformers is not supported in this version of Python.") print("Installation of xformers is not supported in this version of Python.")
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness") print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")

View file

@ -22,6 +22,8 @@ from modules.sd_models import checkpoints_list, find_checkpoint_config
from modules.realesrgan_model import get_realesrgan_models from modules.realesrgan_model import get_realesrgan_models
from modules import devices from modules import devices
from typing import List from typing import List
import piexif
import piexif.helper
def upscaler_to_index(name: str): def upscaler_to_index(name: str):
try: try:
@ -56,18 +58,30 @@ def decode_base64_to_image(encoding):
def encode_pil_to_base64(image): def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes: with io.BytesIO() as output_bytes:
# Copy any text-only metadata if opts.samples_format.lower() == 'png':
use_metadata = False use_metadata = False
metadata = PngImagePlugin.PngInfo() metadata = PngImagePlugin.PngInfo()
for key, value in image.info.items(): for key, value in image.info.items():
if isinstance(key, str) and isinstance(value, str): if isinstance(key, str) and isinstance(value, str):
metadata.add_text(key, value) metadata.add_text(key, value)
use_metadata = True use_metadata = True
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
parameters = image.info.get('parameters', None)
exif_bytes = piexif.dump({
"Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
})
if opts.samples_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality)
else:
image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality)
else:
raise HTTPException(status_code=500, detail="Invalid image format")
image.save(
output_bytes, "PNG", pnginfo=(metadata if use_metadata else None)
)
bytes_data = output_bytes.getvalue() bytes_data = output_bytes.getvalue()
return base64.b64encode(bytes_data) return base64.b64encode(bytes_data)
def api_middleware(app: FastAPI): def api_middleware(app: FastAPI):

View file

@ -67,7 +67,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
pp.image.info["postprocessing"] = infotext pp.image.info["postprocessing"] = infotext
if save_output: if save_output:
images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=pp.info, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
if extras_mode != 2 or show_extras_results: if extras_mode != 2 or show_extras_results:
outputs.append(pp.image) outputs.append(pp.image)

View file

@ -432,6 +432,10 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
"deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"), "deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"),
})) }))
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, { "choices": ["cards", "thumbs"] }),
}))
options_templates.update(options_section(('ui', "User interface"), { options_templates.update(options_section(('ui', "User interface"), {
"return_grid": OptionInfo(True, "Show grid in results for web"), "return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),

View file

@ -26,6 +26,7 @@ class ExtraNetworksPage:
pass pass
def create_html(self, tabname): def create_html(self, tabname):
view = shared.opts.extra_networks_default_view
items_html = '' items_html = ''
for item in self.list_items(): for item in self.list_items():
@ -36,7 +37,7 @@ class ExtraNetworksPage:
items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
res = f""" res = f"""
<div id='{tabname}_{self.name}_cards' class='extra-network-cards'> <div id='{tabname}_{self.name}_cards' class='extra-network-{view}'>
{items_html} {items_html}
</div> </div>
""" """

View file

@ -792,21 +792,78 @@ footer {
display: inline-block; display: inline-block;
max-width: 16em; max-width: 16em;
margin: 0.3em; margin: 0.3em;
align-self: center;
} }
.extra-network-cards .nocards{ #txt2img_extra_view, #img2img_extra_view {
width: auto;
}
.extra-network-cards .nocards, .extra-network-thumbs .nocards{
margin: 1.25em 0.5em 0.5em 0.5em; margin: 1.25em 0.5em 0.5em 0.5em;
} }
.extra-network-cards .nocards h1{ .extra-network-cards .nocards h1, .extra-network-thumbs .nocards h1{
font-size: 1.5em; font-size: 1.5em;
margin-bottom: 1em; margin-bottom: 1em;
} }
.extra-network-cards .nocards li{ .extra-network-cards .nocards li, .extra-network-thumbs .nocards li{
margin-left: 0.5em; margin-left: 0.5em;
} }
.extra-network-thumbs {
display: flex;
flex-flow: row wrap;
gap: 10px;
}
.extra-network-thumbs .card {
height: 6em;
width: 6em;
cursor: pointer;
background-image: url('./file=html/card-no-preview.png');
background-size: cover;
background-position: center center;
position: relative;
}
.extra-network-thumbs .card:hover .additional a {
display: block;
}
.extra-network-thumbs .actions .additional a {
background-image: url('./file=html/image-update.svg');
background-repeat: no-repeat;
background-size: cover;
background-position: center center;
position: absolute;
top: 0;
left: 0;
width: 24px;
height: 24px;
display: none;
font-size: 0;
text-align: -9999;
}
.extra-network-thumbs .actions .name {
position: absolute;
bottom: 0;
font-size: 10px;
padding: 3px;
width: 100%;
overflow: hidden;
white-space: nowrap;
text-overflow: ellipsis;
background: rgba(0,0,0,.5);
}
.extra-network-thumbs .card:hover .actions .name {
white-space: normal;
word-break: break-all;
}
.extra-network-cards .card{ .extra-network-cards .card{
display: inline-block; display: inline-block;
margin: 0.5em; margin: 0.5em;

View file

@ -8,6 +8,7 @@ import re
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
from packaging import version
from modules import import_hook, errors, extra_networks from modules import import_hook, errors, extra_networks
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
@ -49,7 +50,32 @@ else:
server_name = "0.0.0.0" if cmd_opts.listen else None server_name = "0.0.0.0" if cmd_opts.listen else None
def check_versions():
expected_torch_version = "1.13.1"
if version.parse(torch.__version__) < version.parse(expected_torch_version):
errors.print_error_explanation(f"""
You are running torch {torch.__version__}.
The program is tested to work with torch {expected_torch_version}.
To reinstall the desired version, run with commandline flag --reinstall-torch.
Beware that this will cause a lot of large files to be downloaded.
""".strip())
expected_xformers_version = "0.0.16rc425"
if shared.xformers_available:
import xformers
if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
errors.print_error_explanation(f"""
You are running xformers {xformers.__version__}.
The program is tested to work with xformers {expected_xformers_version}.
To reinstall the desired version, run with commandline flag --reinstall-xformers.
""".strip())
def initialize(): def initialize():
check_versions()
extensions.list_extensions() extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir) localization.list_localizations(cmd_opts.localizations_dir)