Removing parts no longer needed to fix vram

This commit is contained in:
Jairo Correa 2022-10-04 22:28:50 -03:00
parent 1f50971fb8
commit 82380d9ac1
2 changed files with 9 additions and 15 deletions

View file

@ -1,7 +1,6 @@
import contextlib import contextlib
import torch import torch
import gc
from modules import errors from modules import errors
@ -20,8 +19,8 @@ def get_optimal_device():
return cpu return cpu
def torch_gc(): def torch_gc():
gc.collect()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.ipc_collect() torch.cuda.ipc_collect()

View file

@ -346,7 +346,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
state.job_count = p.n_iter state.job_count = p.n_iter
for n in range(p.n_iter): for n in range(p.n_iter):
with torch.no_grad(), precision_scope("cuda"), ema_scope():
if state.interrupted: if state.interrupted:
break break
@ -396,21 +395,18 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
x_samples_ddim = modules.safety.censor_batch(x_samples_ddim) x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
for i, x_sample in enumerate(x_samples_ddim): for i, x_sample in enumerate(x_samples_ddim):
with torch.no_grad(), precision_scope("cuda"), ema_scope():
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)
if p.restore_faces: if p.restore_faces:
with torch.no_grad(), precision_scope("cuda"), ema_scope():
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration: if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration") images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
devices.torch_gc()
x_sample = modules.face_restoration.restore_faces(x_sample) x_sample = modules.face_restoration.restore_faces(x_sample)
devices.torch_gc() devices.torch_gc()
devices.torch_gc()
with torch.no_grad(), precision_scope("cuda"), ema_scope():
image = Image.fromarray(x_sample) image = Image.fromarray(x_sample)
if p.color_corrections is not None and i < len(p.color_corrections): if p.color_corrections is not None and i < len(p.color_corrections):
@ -444,7 +440,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
state.nextjob() state.nextjob()
with torch.no_grad(), precision_scope("cuda"), ema_scope():
p.color_corrections = None p.color_corrections = None
index_of_first_image = 0 index_of_first_image = 0