Fix memory leak and reduce memory usage
This commit is contained in:
parent
041d2aefc0
commit
c938679de7
6 changed files with 42 additions and 16 deletions
|
@ -89,7 +89,7 @@ def setup_codeformer():
|
||||||
output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
|
output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
|
||||||
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
||||||
del output
|
del output
|
||||||
torch.cuda.empty_cache()
|
devices.torch_gc()
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
|
print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
|
||||||
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
||||||
|
@ -106,7 +106,9 @@ def setup_codeformer():
|
||||||
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
|
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
if shared.opts.face_restoration_unload:
|
if shared.opts.face_restoration_unload:
|
||||||
self.net.to(devices.cpu)
|
self.net = None
|
||||||
|
self.face_helper = None
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
return restored_img
|
return restored_img
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import torch
|
import torch
|
||||||
|
import gc
|
||||||
|
|
||||||
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
|
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
|
||||||
from modules import errors
|
from modules import errors
|
||||||
|
@ -17,8 +18,8 @@ def get_optimal_device():
|
||||||
|
|
||||||
return cpu
|
return cpu
|
||||||
|
|
||||||
|
|
||||||
def torch_gc():
|
def torch_gc():
|
||||||
|
gc.collect()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
torch.cuda.ipc_collect()
|
torch.cuda.ipc_collect()
|
||||||
|
|
|
@ -98,6 +98,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
|
||||||
|
|
||||||
outputs.append(image)
|
outputs.append(image)
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
return outputs, plaintext_to_html(info), ''
|
return outputs, plaintext_to_html(info), ''
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -49,6 +49,7 @@ def gfpgan():
|
||||||
|
|
||||||
|
|
||||||
def gfpgan_fix_faces(np_image):
|
def gfpgan_fix_faces(np_image):
|
||||||
|
global loaded_gfpgan_model
|
||||||
model = gfpgan()
|
model = gfpgan()
|
||||||
|
|
||||||
np_image_bgr = np_image[:, :, ::-1]
|
np_image_bgr = np_image[:, :, ::-1]
|
||||||
|
@ -56,7 +57,9 @@ def gfpgan_fix_faces(np_image):
|
||||||
np_image = gfpgan_output_bgr[:, :, ::-1]
|
np_image = gfpgan_output_bgr[:, :, ::-1]
|
||||||
|
|
||||||
if shared.opts.face_restoration_unload:
|
if shared.opts.face_restoration_unload:
|
||||||
model.gfpgan.to(devices.cpu)
|
del model
|
||||||
|
loaded_gfpgan_model = None
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
return np_image
|
return np_image
|
||||||
|
|
||||||
|
@ -83,11 +86,7 @@ def setup_gfpgan():
|
||||||
return "GFPGAN"
|
return "GFPGAN"
|
||||||
|
|
||||||
def restore(self, np_image):
|
def restore(self, np_image):
|
||||||
np_image_bgr = np_image[:, :, ::-1]
|
return gfpgan_fix_faces(np_image)
|
||||||
cropped_faces, restored_faces, gfpgan_output_bgr = gfpgan().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
|
|
||||||
np_image = gfpgan_output_bgr[:, :, ::-1]
|
|
||||||
|
|
||||||
return np_image
|
|
||||||
|
|
||||||
shared.face_restorers.append(FaceRestorerGFPGAN())
|
shared.face_restorers.append(FaceRestorerGFPGAN())
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
|
@ -12,7 +12,7 @@ import cv2
|
||||||
from skimage import exposure
|
from skimage import exposure
|
||||||
|
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
from modules import devices, prompt_parser, masking
|
from modules import devices, prompt_parser, masking, lowvram
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.sd_samplers import samplers, samplers_for_img2img
|
from modules.sd_samplers import samplers, samplers_for_img2img
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
|
@ -335,7 +335,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
if state.job_count == -1:
|
if state.job_count == -1:
|
||||||
state.job_count = p.n_iter
|
state.job_count = p.n_iter
|
||||||
|
|
||||||
for n in range(p.n_iter):
|
for n in range(p.n_iter):
|
||||||
|
with torch.no_grad(), precision_scope("cuda"), ema_scope():
|
||||||
if state.interrupted:
|
if state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -368,22 +369,32 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
|
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
|
||||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
|
del samples_ddim
|
||||||
|
|
||||||
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
|
lowvram.send_everything_to_cpu()
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
if opts.filter_nsfw:
|
if opts.filter_nsfw:
|
||||||
import modules.safety as safety
|
import modules.safety as safety
|
||||||
x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
|
x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
|
||||||
|
|
||||||
for i, x_sample in enumerate(x_samples_ddim):
|
for i, x_sample in enumerate(x_samples_ddim):
|
||||||
|
with torch.no_grad(), precision_scope("cuda"), ema_scope():
|
||||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||||
x_sample = x_sample.astype(np.uint8)
|
x_sample = x_sample.astype(np.uint8)
|
||||||
|
|
||||||
if p.restore_faces:
|
if p.restore_faces:
|
||||||
|
with torch.no_grad(), precision_scope("cuda"), ema_scope():
|
||||||
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
|
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
|
||||||
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
|
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
|
||||||
|
|
||||||
devices.torch_gc()
|
|
||||||
|
|
||||||
x_sample = modules.face_restoration.restore_faces(x_sample)
|
x_sample = modules.face_restoration.restore_faces(x_sample)
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
|
with torch.no_grad(), precision_scope("cuda"), ema_scope():
|
||||||
image = Image.fromarray(x_sample)
|
image = Image.fromarray(x_sample)
|
||||||
|
|
||||||
if p.color_corrections is not None and i < len(p.color_corrections):
|
if p.color_corrections is not None and i < len(p.color_corrections):
|
||||||
|
@ -411,8 +422,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
infotexts.append(infotext(n, i))
|
infotexts.append(infotext(n, i))
|
||||||
output_images.append(image)
|
output_images.append(image)
|
||||||
|
|
||||||
state.nextjob()
|
del x_samples_ddim
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
|
state.nextjob()
|
||||||
|
|
||||||
|
with torch.no_grad(), precision_scope("cuda"), ema_scope():
|
||||||
p.color_corrections = None
|
p.color_corrections = None
|
||||||
|
|
||||||
index_of_first_image = 0
|
index_of_first_image = 0
|
||||||
|
@ -648,4 +664,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
samples = samples * self.nmask + self.init_latent * self.mask
|
samples = samples * self.nmask + self.init_latent * self.mask
|
||||||
|
|
||||||
|
del x
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
3
webui.py
3
webui.py
|
@ -22,7 +22,10 @@ import modules.txt2img
|
||||||
import modules.img2img
|
import modules.img2img
|
||||||
import modules.swinir as swinir
|
import modules.swinir as swinir
|
||||||
import modules.sd_models
|
import modules.sd_models
|
||||||
|
from torch.nn.functional import silu
|
||||||
|
import ldm
|
||||||
|
|
||||||
|
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||||
|
|
||||||
modules.codeformer_model.setup_codeformer()
|
modules.codeformer_model.setup_codeformer()
|
||||||
modules.gfpgan_model.setup_gfpgan()
|
modules.gfpgan_model.setup_gfpgan()
|
||||||
|
|
Loading…
Reference in a new issue