Merge pull request #1283 from jn-jairo/fix-vram
Fix memory leak and reduce memory usage
This commit is contained in:
commit
2cfcb23c16
4 changed files with 20 additions and 10 deletions
|
@ -100,6 +100,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
|
||||||
|
|
||||||
outputs.append(image)
|
outputs.append(image)
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
return outputs, plaintext_to_html(info), ''
|
return outputs, plaintext_to_html(info), ''
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,7 @@ import cv2
|
||||||
from skimage import exposure
|
from skimage import exposure
|
||||||
|
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
from modules import devices, prompt_parser, masking, sd_samplers
|
from modules import devices, prompt_parser, masking, sd_samplers, lowvram
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
@ -382,6 +382,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
|
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
|
||||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
|
del samples_ddim
|
||||||
|
|
||||||
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
|
lowvram.send_everything_to_cpu()
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
if opts.filter_nsfw:
|
if opts.filter_nsfw:
|
||||||
import modules.safety as safety
|
import modules.safety as safety
|
||||||
x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
|
x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
|
||||||
|
@ -426,6 +433,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
infotexts.append(infotext(n, i))
|
infotexts.append(infotext(n, i))
|
||||||
output_images.append(image)
|
output_images.append(image)
|
||||||
|
|
||||||
|
del x_samples_ddim
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
state.nextjob()
|
state.nextjob()
|
||||||
|
|
||||||
p.color_corrections = None
|
p.color_corrections = None
|
||||||
|
@ -663,4 +674,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
samples = samples * self.nmask + self.init_latent * self.mask
|
samples = samples * self.nmask + self.init_latent * self.mask
|
||||||
|
|
||||||
|
del x
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
|
@ -5,6 +5,7 @@ import traceback
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
|
from torch.nn.functional import silu
|
||||||
|
|
||||||
import modules.textual_inversion.textual_inversion
|
import modules.textual_inversion.textual_inversion
|
||||||
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
|
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
|
||||||
|
@ -19,11 +20,12 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At
|
||||||
|
|
||||||
|
|
||||||
def apply_optimizations():
|
def apply_optimizations():
|
||||||
|
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||||
|
|
||||||
if cmd_opts.opt_split_attention_v1:
|
if cmd_opts.opt_split_attention_v1:
|
||||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
||||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
||||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
|
||||||
ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack
|
|
||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -92,14 +92,6 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||||
|
|
||||||
return self.to_out(r2)
|
return self.to_out(r2)
|
||||||
|
|
||||||
def nonlinearity_hijack(x):
|
|
||||||
# swish
|
|
||||||
t = torch.sigmoid(x)
|
|
||||||
x *= t
|
|
||||||
del t
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def cross_attention_attnblock_forward(self, x):
|
def cross_attention_attnblock_forward(self, x):
|
||||||
h_ = x
|
h_ = x
|
||||||
h_ = self.norm(h_)
|
h_ = self.norm(h_)
|
||||||
|
|
Loading…
Reference in a new issue