Removing parts no longer needed to fix vram
This commit is contained in:
parent
1f50971fb8
commit
82380d9ac1
2 changed files with 9 additions and 15 deletions
|
@ -1,7 +1,6 @@
|
||||||
import contextlib
|
import contextlib
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import gc
|
|
||||||
|
|
||||||
from modules import errors
|
from modules import errors
|
||||||
|
|
||||||
|
@ -20,8 +19,8 @@ def get_optimal_device():
|
||||||
|
|
||||||
return cpu
|
return cpu
|
||||||
|
|
||||||
|
|
||||||
def torch_gc():
|
def torch_gc():
|
||||||
gc.collect()
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
torch.cuda.ipc_collect()
|
torch.cuda.ipc_collect()
|
||||||
|
|
|
@ -345,8 +345,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
if state.job_count == -1:
|
if state.job_count == -1:
|
||||||
state.job_count = p.n_iter
|
state.job_count = p.n_iter
|
||||||
|
|
||||||
for n in range(p.n_iter):
|
for n in range(p.n_iter):
|
||||||
with torch.no_grad(), precision_scope("cuda"), ema_scope():
|
|
||||||
if state.interrupted:
|
if state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -395,22 +394,19 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
import modules.safety as safety
|
import modules.safety as safety
|
||||||
x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
|
x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
|
||||||
|
|
||||||
for i, x_sample in enumerate(x_samples_ddim):
|
for i, x_sample in enumerate(x_samples_ddim):
|
||||||
with torch.no_grad(), precision_scope("cuda"), ema_scope():
|
|
||||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||||
x_sample = x_sample.astype(np.uint8)
|
x_sample = x_sample.astype(np.uint8)
|
||||||
|
|
||||||
if p.restore_faces:
|
if p.restore_faces:
|
||||||
with torch.no_grad(), precision_scope("cuda"), ema_scope():
|
|
||||||
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
|
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
|
||||||
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
|
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
x_sample = modules.face_restoration.restore_faces(x_sample)
|
x_sample = modules.face_restoration.restore_faces(x_sample)
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
devices.torch_gc()
|
|
||||||
|
|
||||||
with torch.no_grad(), precision_scope("cuda"), ema_scope():
|
|
||||||
image = Image.fromarray(x_sample)
|
image = Image.fromarray(x_sample)
|
||||||
|
|
||||||
if p.color_corrections is not None and i < len(p.color_corrections):
|
if p.color_corrections is not None and i < len(p.color_corrections):
|
||||||
|
@ -438,13 +434,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
infotexts.append(infotext(n, i))
|
infotexts.append(infotext(n, i))
|
||||||
output_images.append(image)
|
output_images.append(image)
|
||||||
|
|
||||||
del x_samples_ddim
|
del x_samples_ddim
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
state.nextjob()
|
state.nextjob()
|
||||||
|
|
||||||
with torch.no_grad(), precision_scope("cuda"), ema_scope():
|
|
||||||
p.color_corrections = None
|
p.color_corrections = None
|
||||||
|
|
||||||
index_of_first_image = 0
|
index_of_first_image = 0
|
||||||
|
|
Loading…
Reference in a new issue