add a message about new torch/xformers version and a way to upgrade by specifying a commandline flag

This commit is contained in:
AUTOMATIC 2023-01-23 17:17:31 +03:00
parent 56f63cd498
commit 7ff1ef77dd
2 changed files with 28 additions and 1 deletions

View file

@ -208,6 +208,7 @@ def prepare_environment():
sys.argv, _ = extract_arg(sys.argv, '-f') sys.argv, _ = extract_arg(sys.argv, '-f')
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test') sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers') sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
sys.argv, update_check = extract_arg(sys.argv, '--update-check') sys.argv, update_check = extract_arg(sys.argv, '--update-check')
sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests') sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
sys.argv, skip_install = extract_arg(sys.argv, '--skip-install') sys.argv, skip_install = extract_arg(sys.argv, '--skip-install')
@ -219,7 +220,7 @@ def prepare_environment():
print(f"Python {sys.version}") print(f"Python {sys.version}")
print(f"Commit hash: {commit}") print(f"Commit hash: {commit}")
if not is_installed("torch") or not is_installed("torchvision"): if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch") run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")
if not skip_torch_cuda_test: if not skip_torch_cuda_test:

View file

@ -8,6 +8,7 @@ import re
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
from packaging import version
from modules import import_hook, errors, extra_networks from modules import import_hook, errors, extra_networks
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
@ -49,7 +50,32 @@ else:
server_name = "0.0.0.0" if cmd_opts.listen else None server_name = "0.0.0.0" if cmd_opts.listen else None
def check_versions():
expected_torch_version = "1.13.1"
if version.parse(torch.__version__) < version.parse(expected_torch_version):
errors.print_error_explanation(f"""
You are running torch {torch.__version__}.
The program is tested to work with torch {expected_torch_version}.
To reinstall the desired version, run with commandline flag --reinstall-torch.
Beware that this will cause a lot of large files to be downloaded.
""".strip())
expected_xformers_version = "0.0.16rc425"
if shared.xformers_available:
import xformers
if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
errors.print_error_explanation(f"""
You are running xformers {xformers.__version__}.
The program is tested to work with xformers {expected_xformers_version}.
To reinstall the desired version, run with commandline flag --reinstall-xformers.
""".strip())
def initialize(): def initialize():
check_versions()
extensions.list_extensions() extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir) localization.list_localizations(cmd_opts.localizations_dir)