Merge remote-tracking branch 'origin/master'
This commit is contained in:
commit
2e6153e343
4 changed files with 19 additions and 13 deletions
12
modules/devices.py
Normal file
12
modules/devices.py
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
|
||||||
|
has_mps = getattr(torch, 'has_mps', False)
|
||||||
|
|
||||||
|
def get_optimal_device():
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return torch.device("cuda")
|
||||||
|
if has_mps:
|
||||||
|
return torch.device("mps")
|
||||||
|
return torch.device("cpu")
|
|
@ -9,12 +9,13 @@ from PIL import Image
|
||||||
import modules.esrgam_model_arch as arch
|
import modules.esrgam_model_arch as arch
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
from modules.devices import has_mps
|
||||||
import modules.images
|
import modules.images
|
||||||
|
|
||||||
|
|
||||||
def load_model(filename):
|
def load_model(filename):
|
||||||
# this code is adapted from https://github.com/xinntao/ESRGAN
|
# this code is adapted from https://github.com/xinntao/ESRGAN
|
||||||
pretrained_net = torch.load(filename, map_location='cpu' if torch.has_mps else None)
|
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
|
||||||
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
||||||
|
|
||||||
if 'conv_first.weight' in pretrained_net:
|
if 'conv_first.weight' in pretrained_net:
|
||||||
|
|
|
@ -1,13 +1,9 @@
|
||||||
import torch
|
import torch
|
||||||
|
from modules.devices import get_optimal_device
|
||||||
|
|
||||||
module_in_gpu = None
|
module_in_gpu = None
|
||||||
cpu = torch.device("cpu")
|
cpu = torch.device("cpu")
|
||||||
if torch.has_cuda:
|
device = gpu = get_optimal_device()
|
||||||
device = gpu = torch.device("cuda")
|
|
||||||
elif torch.has_mps:
|
|
||||||
device = gpu = torch.device("mps")
|
|
||||||
else:
|
|
||||||
device = gpu = torch.device("cpu")
|
|
||||||
|
|
||||||
def setup_for_low_vram(sd_model, use_medvram):
|
def setup_for_low_vram(sd_model, use_medvram):
|
||||||
parents = {}
|
parents = {}
|
||||||
|
|
|
@ -9,6 +9,7 @@ import tqdm
|
||||||
|
|
||||||
import modules.artists
|
import modules.artists
|
||||||
from modules.paths import script_path, sd_path
|
from modules.paths import script_path, sd_path
|
||||||
|
from modules.devices import get_optimal_device
|
||||||
import modules.styles
|
import modules.styles
|
||||||
|
|
||||||
config_filename = "config.json"
|
config_filename = "config.json"
|
||||||
|
@ -43,12 +44,8 @@ parser.add_argument("--ui-config-file", type=str, help="filename to use for ui c
|
||||||
|
|
||||||
cmd_opts = parser.parse_args()
|
cmd_opts = parser.parse_args()
|
||||||
|
|
||||||
if torch.has_cuda:
|
device = get_optimal_device()
|
||||||
device = torch.device("cuda")
|
|
||||||
elif torch.has_mps:
|
|
||||||
device = torch.device("mps")
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
|
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
|
||||||
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
|
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue