commit
48feae37ff
6 changed files with 55 additions and 6 deletions
|
@ -4,6 +4,7 @@ import os
|
|||
import sys
|
||||
import importlib.util
|
||||
import shlex
|
||||
import platform
|
||||
|
||||
dir_repos = "repositories"
|
||||
dir_tmp = "tmp"
|
||||
|
@ -31,6 +32,7 @@ def extract_arg(args, name):
|
|||
|
||||
|
||||
args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test')
|
||||
args, xformers = extract_arg(args, '--xformers')
|
||||
|
||||
|
||||
def repo_dir(name):
|
||||
|
@ -124,6 +126,12 @@ if not is_installed("gfpgan"):
|
|||
if not is_installed("clip"):
|
||||
run_pip(f"install {clip_package}", "clip")
|
||||
|
||||
if not is_installed("xformers") and xformers:
|
||||
if platform.system() == "Windows":
|
||||
run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/a/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers")
|
||||
elif platform.system() == "Linux":
|
||||
run_pip("install xformers", "xformers")
|
||||
|
||||
os.makedirs(dir_repos, exist_ok=True)
|
||||
|
||||
git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||
|
|
|
@ -22,11 +22,13 @@ def apply_optimizations():
|
|||
undo_optimizations()
|
||||
|
||||
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||
|
||||
if cmd_opts.opt_split_attention_v1:
|
||||
if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip) and shared.xformers_available:
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
|
||||
elif cmd_opts.opt_split_attention_v1:
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
|
||||
elif cmd_opts.opt_split_attention or torch.cuda.is_available():
|
||||
ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,14 @@
|
|||
import math
|
||||
import torch
|
||||
from torch import einsum
|
||||
|
||||
try:
|
||||
import xformers.ops
|
||||
import functorch
|
||||
xformers._is_functorch_available = True
|
||||
shared.xformers_available = True
|
||||
except:
|
||||
print('Cannot find xformers, defaulting to split attention. Try setting --xformers in your webui-user file if you wish to install it.')
|
||||
continue
|
||||
from ldm.util import default
|
||||
from einops import rearrange
|
||||
|
||||
|
@ -115,6 +122,25 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
|||
|
||||
return self.to_out(r2)
|
||||
|
||||
def xformers_attention_forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
hypernetwork = shared.selected_hypernetwork()
|
||||
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
|
||||
if hypernetwork_layers is not None:
|
||||
k_in = self.to_k(hypernetwork_layers[0](context))
|
||||
v_in = self.to_v(hypernetwork_layers[1](context))
|
||||
else:
|
||||
k_in = self.to_k(context)
|
||||
v_in = self.to_v(context)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
|
||||
|
||||
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
||||
return self.to_out(out)
|
||||
|
||||
def cross_attention_attnblock_forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
|
@ -177,3 +203,13 @@ def cross_attention_attnblock_forward(self, x):
|
|||
h3 += x
|
||||
|
||||
return h3
|
||||
|
||||
def xformers_attnblock_forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q1 = self.q(h_).contiguous()
|
||||
k1 = self.k(h_).contiguous()
|
||||
v = self.v(h_).contiguous()
|
||||
out = xformers.ops.memory_efficient_attention(q1, k1, v)
|
||||
out = self.proj_out(out)
|
||||
return x+out
|
||||
|
|
|
@ -43,6 +43,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director
|
|||
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
|
||||
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
|
||||
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
|
||||
parser.add_argument("--disable-opt-xformers-attention", action='store_true', help="force-disables xformers attention optimization")
|
||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||
|
@ -73,7 +74,7 @@ device = devices.device
|
|||
|
||||
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
|
||||
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
|
||||
|
||||
xformers_available = False
|
||||
config_filename = cmd_opts.ui_settings_file
|
||||
|
||||
hypernetworks = hypernetwork.load_hypernetworks(os.path.join(models_path, 'hypernetworks'))
|
||||
|
|
|
@ -23,3 +23,4 @@ resize-right
|
|||
torchdiffeq
|
||||
kornia
|
||||
lark
|
||||
functorch
|
||||
|
|
|
@ -22,3 +22,4 @@ resize-right==0.0.2
|
|||
torchdiffeq==0.2.3
|
||||
kornia==0.6.7
|
||||
lark==1.1.2
|
||||
functorch==0.2.1
|
||||
|
|
Loading…
Reference in a new issue