add safetensors to requirements

This commit is contained in:
AUTOMATIC 2022-11-27 14:46:40 +03:00
parent f108782e30
commit 6074175faa
3 changed files with 7 additions and 6 deletions

View file

@ -5,6 +5,7 @@ import gc
from collections import namedtuple from collections import namedtuple
import torch import torch
import re import re
import safetensors.torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
@ -173,14 +174,12 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
# load from file # load from file
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
if checkpoint_file.endswith(".safetensors"): _, extension = os.path.splitext(checkpoint_file)
try: if extension.lower() == ".safetensors":
from safetensors.torch import load_file pl_sd = safetensors.torch.load_file(checkpoint_file, device=shared.weight_load_location)
except ImportError as e:
raise ImportError(f"The model is in safetensors format and it is not installed, use `pip install safetensors`: {e}")
pl_sd = load_file(checkpoint_file, device=shared.weight_load_location)
else: else:
pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
if "global_step" in pl_sd: if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}") print(f"Global Step: {pl_sd['global_step']}")

View file

@ -29,3 +29,4 @@ lark
inflection inflection
GitPython GitPython
torchsde torchsde
safetensors

View file

@ -26,3 +26,4 @@ lark==1.1.2
inflection==0.5.1 inflection==0.5.1
GitPython==3.1.27 GitPython==3.1.27
torchsde==0.2.5 torchsde==0.2.5
safetensors==0.2.5