send all three of GFPGAN's and codeformer's models to CPU memory instead of just one for #1283
This commit is contained in:
parent
556c36b960
commit
6c6ae28bf5
4 changed files with 41 additions and 11 deletions
|
@ -69,10 +69,14 @@ def setup_model(dirname):
|
|||
|
||||
self.net = net
|
||||
self.face_helper = face_helper
|
||||
self.net.to(devices.device_codeformer)
|
||||
|
||||
return net, face_helper
|
||||
|
||||
def send_model_to(self, device):
|
||||
self.net.to(device)
|
||||
self.face_helper.face_det.to(device)
|
||||
self.face_helper.face_parse.to(device)
|
||||
|
||||
def restore(self, np_image, w=None):
|
||||
np_image = np_image[:, :, ::-1]
|
||||
|
||||
|
@ -82,6 +86,8 @@ def setup_model(dirname):
|
|||
if self.net is None or self.face_helper is None:
|
||||
return np_image
|
||||
|
||||
self.send_model_to(devices.device_codeformer)
|
||||
|
||||
self.face_helper.clean_all()
|
||||
self.face_helper.read_image(np_image)
|
||||
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
||||
|
@ -113,8 +119,10 @@ def setup_model(dirname):
|
|||
if original_resolution != restored_img.shape[0:2]:
|
||||
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
self.face_helper.clean_all()
|
||||
|
||||
if shared.opts.face_restoration_unload:
|
||||
self.net.to(devices.cpu)
|
||||
self.send_model_to(devices.cpu)
|
||||
|
||||
return restored_img
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import contextlib
|
||||
|
||||
import torch
|
||||
|
||||
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
|
||||
|
@ -57,3 +59,11 @@ def randn_without_seed(shape):
|
|||
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
|
||||
def autocast():
|
||||
from modules import shared
|
||||
|
||||
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
|
||||
return contextlib.nullcontext()
|
||||
|
||||
return torch.autocast("cuda")
|
||||
|
|
|
@ -37,22 +37,32 @@ def gfpgann():
|
|||
print("Unable to load gfpgan model!")
|
||||
return None
|
||||
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
|
||||
model.gfpgan.to(shared.device)
|
||||
loaded_gfpgan_model = model
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def send_model_to(model, device):
|
||||
model.gfpgan.to(device)
|
||||
model.face_helper.face_det.to(device)
|
||||
model.face_helper.face_parse.to(device)
|
||||
|
||||
|
||||
def gfpgan_fix_faces(np_image):
|
||||
model = gfpgann()
|
||||
if model is None:
|
||||
return np_image
|
||||
|
||||
send_model_to(model, devices.device)
|
||||
|
||||
np_image_bgr = np_image[:, :, ::-1]
|
||||
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
|
||||
np_image = gfpgan_output_bgr[:, :, ::-1]
|
||||
|
||||
model.face_helper.clean_all()
|
||||
|
||||
if shared.opts.face_restoration_unload:
|
||||
model.gfpgan.to(devices.cpu)
|
||||
send_model_to(model, devices.cpu)
|
||||
|
||||
return np_image
|
||||
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import contextlib
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
|
@ -330,9 +329,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||
|
||||
infotexts = []
|
||||
output_images = []
|
||||
precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
|
||||
ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
|
||||
with torch.no_grad(), precision_scope("cuda"), ema_scope():
|
||||
|
||||
with torch.no_grad():
|
||||
p.init(all_prompts, all_seeds, all_subseeds)
|
||||
|
||||
if state.job_count == -1:
|
||||
|
@ -351,8 +349,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||
|
||||
#uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
|
||||
#c = p.sd_model.get_learned_conditioning(prompts)
|
||||
uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
|
||||
c = prompt_parser.get_learned_conditioning(prompts, p.steps)
|
||||
with devices.autocast():
|
||||
uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
|
||||
c = prompt_parser.get_learned_conditioning(prompts, p.steps)
|
||||
|
||||
if len(model_hijack.comments) > 0:
|
||||
for comment in model_hijack.comments:
|
||||
|
@ -361,7 +360,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||
if p.n_iter > 1:
|
||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||
|
||||
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
|
||||
with devices.autocast():
|
||||
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength).to(devices.dtype)
|
||||
|
||||
if state.interrupted:
|
||||
|
||||
# if we are interruped, sample returns just noise
|
||||
|
@ -386,6 +387,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||
devices.torch_gc()
|
||||
|
||||
x_sample = modules.face_restoration.restore_faces(x_sample)
|
||||
devices.torch_gc()
|
||||
|
||||
image = Image.fromarray(x_sample)
|
||||
|
||||
|
|
Loading…
Reference in a new issue