Merge branch 'AUTOMATIC1111:master' into interrogate
This commit is contained in:
commit
efa7287be0
9 changed files with 140 additions and 27 deletions
|
@ -57,6 +57,7 @@ class LoraUpDownModule:
|
|||
def __init__(self):
|
||||
self.up = None
|
||||
self.down = None
|
||||
self.alpha = None
|
||||
|
||||
|
||||
def assign_lora_names_to_compvis_modules(sd_model):
|
||||
|
@ -92,6 +93,15 @@ def load_lora(name, filename):
|
|||
keys_failed_to_match.append(key_diffusers)
|
||||
continue
|
||||
|
||||
lora_module = lora.modules.get(key, None)
|
||||
if lora_module is None:
|
||||
lora_module = LoraUpDownModule()
|
||||
lora.modules[key] = lora_module
|
||||
|
||||
if lora_key == "alpha":
|
||||
lora_module.alpha = weight.item()
|
||||
continue
|
||||
|
||||
if type(sd_module) == torch.nn.Linear:
|
||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||
elif type(sd_module) == torch.nn.Conv2d:
|
||||
|
@ -104,17 +114,12 @@ def load_lora(name, filename):
|
|||
|
||||
module.to(device=devices.device, dtype=devices.dtype)
|
||||
|
||||
lora_module = lora.modules.get(key, None)
|
||||
if lora_module is None:
|
||||
lora_module = LoraUpDownModule()
|
||||
lora.modules[key] = lora_module
|
||||
|
||||
if lora_key == "lora_up.weight":
|
||||
lora_module.up = module
|
||||
elif lora_key == "lora_down.weight":
|
||||
lora_module.down = module
|
||||
else:
|
||||
assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight or lora_down.weight'
|
||||
assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
|
||||
|
||||
if len(keys_failed_to_match) > 0:
|
||||
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
|
||||
|
@ -161,7 +166,7 @@ def lora_forward(module, input, res):
|
|||
for lora in loaded_loras:
|
||||
module = lora.modules.get(lora_layer_name, None)
|
||||
if module is not None:
|
||||
res = res + module.up(module.down(input)) * lora.multiplier
|
||||
res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
||||
|
||||
return res
|
||||
|
||||
|
|
7
html/image-update.svg
Normal file
7
html/image-update.svg
Normal file
|
@ -0,0 +1,7 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24">
|
||||
<filter id='shadow' color-interpolation-filters="sRGB">
|
||||
<feDropShadow flood-color="black" dx="0" dy="0" flood-opacity="0.9" stdDeviation="0.5"/>
|
||||
<feDropShadow flood-color="black" dx="0" dy="0" flood-opacity="0.9" stdDeviation="0.5"/>
|
||||
</filter>
|
||||
<path style="filter:url(#shadow);" fill="#FFFFFF" d="M13.18 19C13.35 19.72 13.64 20.39 14.03 21H5C3.9 21 3 20.11 3 19V5C3 3.9 3.9 3 5 3H19C20.11 3 21 3.9 21 5V11.18C20.5 11.07 20 11 19.5 11C19.33 11 19.17 11 19 11.03V5H5V19H13.18M11.21 15.83L9.25 13.47L6.5 17H13.03C13.14 15.54 13.73 14.22 14.64 13.19L13.96 12.29L11.21 15.83M19 13.5V12L16.75 14.25L19 16.5V15C20.38 15 21.5 16.12 21.5 17.5C21.5 17.9 21.41 18.28 21.24 18.62L22.33 19.71C22.75 19.08 23 18.32 23 17.5C23 15.29 21.21 13.5 19 13.5M19 20C17.62 20 16.5 18.88 16.5 17.5C16.5 17.1 16.59 16.72 16.76 16.38L15.67 15.29C15.25 15.92 15 16.68 15 17.5C15 19.71 16.79 21.5 19 21.5V23L21.25 20.75L19 18.5V20Z" />
|
||||
</svg>
|
After Width: | Height: | Size: 989 B |
|
@ -179,7 +179,7 @@ def run_extensions_installers(settings_file):
|
|||
def prepare_environment():
|
||||
global skip_install
|
||||
|
||||
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
|
||||
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
|
||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
||||
|
||||
|
@ -187,8 +187,6 @@ def prepare_environment():
|
|||
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
|
||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
|
||||
|
||||
xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl')
|
||||
|
||||
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
||||
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
|
||||
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
||||
|
@ -210,6 +208,7 @@ def prepare_environment():
|
|||
sys.argv, _ = extract_arg(sys.argv, '-f')
|
||||
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
|
||||
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
|
||||
sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
|
||||
sys.argv, update_check = extract_arg(sys.argv, '--update-check')
|
||||
sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
|
||||
sys.argv, skip_install = extract_arg(sys.argv, '--skip-install')
|
||||
|
@ -221,7 +220,7 @@ def prepare_environment():
|
|||
print(f"Python {sys.version}")
|
||||
print(f"Commit hash: {commit}")
|
||||
|
||||
if not is_installed("torch") or not is_installed("torchvision"):
|
||||
if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
||||
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")
|
||||
|
||||
if not skip_torch_cuda_test:
|
||||
|
@ -239,7 +238,7 @@ def prepare_environment():
|
|||
if (not is_installed("xformers") or reinstall_xformers) and xformers:
|
||||
if platform.system() == "Windows":
|
||||
if platform.python_version().startswith("3.10"):
|
||||
run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers")
|
||||
run_pip(f"install -U -I --no-deps xformers==0.0.16rc425", "xformers")
|
||||
else:
|
||||
print("Installation of xformers is not supported in this version of Python.")
|
||||
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
|
||||
|
|
|
@ -22,6 +22,8 @@ from modules.sd_models import checkpoints_list, find_checkpoint_config
|
|||
from modules.realesrgan_model import get_realesrgan_models
|
||||
from modules import devices
|
||||
from typing import List
|
||||
import piexif
|
||||
import piexif.helper
|
||||
|
||||
def upscaler_to_index(name: str):
|
||||
try:
|
||||
|
@ -56,18 +58,30 @@ def decode_base64_to_image(encoding):
|
|||
def encode_pil_to_base64(image):
|
||||
with io.BytesIO() as output_bytes:
|
||||
|
||||
# Copy any text-only metadata
|
||||
if opts.samples_format.lower() == 'png':
|
||||
use_metadata = False
|
||||
metadata = PngImagePlugin.PngInfo()
|
||||
for key, value in image.info.items():
|
||||
if isinstance(key, str) and isinstance(value, str):
|
||||
metadata.add_text(key, value)
|
||||
use_metadata = True
|
||||
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
|
||||
|
||||
elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
|
||||
parameters = image.info.get('parameters', None)
|
||||
exif_bytes = piexif.dump({
|
||||
"Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
|
||||
})
|
||||
if opts.samples_format.lower() in ("jpg", "jpeg"):
|
||||
image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality)
|
||||
else:
|
||||
image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality)
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Invalid image format")
|
||||
|
||||
image.save(
|
||||
output_bytes, "PNG", pnginfo=(metadata if use_metadata else None)
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
|
||||
return base64.b64encode(bytes_data)
|
||||
|
||||
def api_middleware(app: FastAPI):
|
||||
|
|
|
@ -67,7 +67,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||
pp.image.info["postprocessing"] = infotext
|
||||
|
||||
if save_output:
|
||||
images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=pp.info, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
|
||||
images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
|
||||
|
||||
if extras_mode != 2 or show_extras_results:
|
||||
outputs.append(pp.image)
|
||||
|
|
|
@ -432,6 +432,10 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
|
|||
"deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
||||
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, { "choices": ["cards", "thumbs"] }),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "User interface"), {
|
||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||
|
|
|
@ -26,6 +26,7 @@ class ExtraNetworksPage:
|
|||
pass
|
||||
|
||||
def create_html(self, tabname):
|
||||
view = shared.opts.extra_networks_default_view
|
||||
items_html = ''
|
||||
|
||||
for item in self.list_items():
|
||||
|
@ -36,7 +37,7 @@ class ExtraNetworksPage:
|
|||
items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
|
||||
|
||||
res = f"""
|
||||
<div id='{tabname}_{self.name}_cards' class='extra-network-cards'>
|
||||
<div id='{tabname}_{self.name}_cards' class='extra-network-{view}'>
|
||||
{items_html}
|
||||
</div>
|
||||
"""
|
||||
|
|
63
style.css
63
style.css
|
@ -792,21 +792,78 @@ footer {
|
|||
display: inline-block;
|
||||
max-width: 16em;
|
||||
margin: 0.3em;
|
||||
align-self: center;
|
||||
}
|
||||
|
||||
.extra-network-cards .nocards{
|
||||
#txt2img_extra_view, #img2img_extra_view {
|
||||
width: auto;
|
||||
}
|
||||
|
||||
.extra-network-cards .nocards, .extra-network-thumbs .nocards{
|
||||
margin: 1.25em 0.5em 0.5em 0.5em;
|
||||
}
|
||||
|
||||
.extra-network-cards .nocards h1{
|
||||
.extra-network-cards .nocards h1, .extra-network-thumbs .nocards h1{
|
||||
font-size: 1.5em;
|
||||
margin-bottom: 1em;
|
||||
}
|
||||
|
||||
.extra-network-cards .nocards li{
|
||||
.extra-network-cards .nocards li, .extra-network-thumbs .nocards li{
|
||||
margin-left: 0.5em;
|
||||
}
|
||||
|
||||
.extra-network-thumbs {
|
||||
display: flex;
|
||||
flex-flow: row wrap;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.extra-network-thumbs .card {
|
||||
height: 6em;
|
||||
width: 6em;
|
||||
cursor: pointer;
|
||||
background-image: url('./file=html/card-no-preview.png');
|
||||
background-size: cover;
|
||||
background-position: center center;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.extra-network-thumbs .card:hover .additional a {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.extra-network-thumbs .actions .additional a {
|
||||
background-image: url('./file=html/image-update.svg');
|
||||
background-repeat: no-repeat;
|
||||
background-size: cover;
|
||||
background-position: center center;
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
width: 24px;
|
||||
height: 24px;
|
||||
display: none;
|
||||
font-size: 0;
|
||||
text-align: -9999;
|
||||
}
|
||||
|
||||
.extra-network-thumbs .actions .name {
|
||||
position: absolute;
|
||||
bottom: 0;
|
||||
font-size: 10px;
|
||||
padding: 3px;
|
||||
width: 100%;
|
||||
overflow: hidden;
|
||||
white-space: nowrap;
|
||||
text-overflow: ellipsis;
|
||||
background: rgba(0,0,0,.5);
|
||||
}
|
||||
|
||||
.extra-network-thumbs .card:hover .actions .name {
|
||||
white-space: normal;
|
||||
word-break: break-all;
|
||||
}
|
||||
|
||||
.extra-network-cards .card{
|
||||
display: inline-block;
|
||||
margin: 0.5em;
|
||||
|
|
26
webui.py
26
webui.py
|
@ -8,6 +8,7 @@ import re
|
|||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from packaging import version
|
||||
|
||||
from modules import import_hook, errors, extra_networks
|
||||
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
|
||||
|
@ -49,7 +50,32 @@ else:
|
|||
server_name = "0.0.0.0" if cmd_opts.listen else None
|
||||
|
||||
|
||||
def check_versions():
|
||||
expected_torch_version = "1.13.1"
|
||||
|
||||
if version.parse(torch.__version__) < version.parse(expected_torch_version):
|
||||
errors.print_error_explanation(f"""
|
||||
You are running torch {torch.__version__}.
|
||||
The program is tested to work with torch {expected_torch_version}.
|
||||
To reinstall the desired version, run with commandline flag --reinstall-torch.
|
||||
Beware that this will cause a lot of large files to be downloaded.
|
||||
""".strip())
|
||||
|
||||
expected_xformers_version = "0.0.16rc425"
|
||||
if shared.xformers_available:
|
||||
import xformers
|
||||
|
||||
if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
|
||||
errors.print_error_explanation(f"""
|
||||
You are running xformers {xformers.__version__}.
|
||||
The program is tested to work with xformers {expected_xformers_version}.
|
||||
To reinstall the desired version, run with commandline flag --reinstall-xformers.
|
||||
""".strip())
|
||||
|
||||
|
||||
def initialize():
|
||||
check_versions()
|
||||
|
||||
extensions.list_extensions()
|
||||
localization.list_localizations(cmd_opts.localizations_dir)
|
||||
|
||||
|
|
Loading…
Reference in a new issue