make launch.py run installers for extensions that have ones

add some more classes to safety module for an extension
This commit is contained in:
AUTOMATIC 2022-11-01 14:19:24 +03:00
parent f126986b76
commit d35bf64945
2 changed files with 21 additions and 3 deletions

View file

@ -7,6 +7,7 @@ import shlex
import platform import platform
dir_repos = "repositories" dir_repos = "repositories"
dir_extensions = "extensions"
python = sys.executable python = sys.executable
git = os.environ.get('GIT', "git") git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "") index_url = os.environ.get('INDEX_URL', "")
@ -101,9 +102,24 @@ def version_check(commit):
else: else:
print("Not a git clone, can't perform version check.") print("Not a git clone, can't perform version check.")
except Exception as e: except Exception as e:
print("versipm check failed",e) print("version check failed", e)
def run_extensions_installers():
if not os.path.isdir(dir_extensions):
return
for dirname_extension in os.listdir(dir_extensions):
path_installer = os.path.join(dir_extensions, dirname_extension, "install.py")
if not os.path.isfile(path_installer):
continue
try:
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {dirname_extension}"))
except Exception as e:
print(e, file=sys.stderr)
def prepare_enviroment(): def prepare_enviroment():
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
@ -189,6 +205,8 @@ def prepare_enviroment():
run_pip(f"install -r {requirements_file}", "requirements for Web UI") run_pip(f"install -r {requirements_file}", "requirements for Web UI")
run_extensions_installers()
if update_check: if update_check:
version_check(commit) version_check(commit)

View file

@ -32,7 +32,7 @@ class RestrictedUnpickler(pickle.Unpickler):
return getattr(collections, name) return getattr(collections, name)
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']: if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
return getattr(torch._utils, name) return getattr(torch._utils, name)
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']: if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage']:
return getattr(torch, name) return getattr(torch, name)
if module == 'torch.nn.modules.container' and name in ['ParameterDict']: if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
return getattr(torch.nn.modules.container, name) return getattr(torch.nn.modules.container, name)