better support for xformers flash attention on older versions of torch

This commit is contained in:
AUTOMATIC 2023-01-23 16:40:20 +03:00
parent 3fa482076a
commit 59146621e2
2 changed files with 30 additions and 24 deletions

View file

@ -24,6 +24,18 @@ See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable
""") """)
already_displayed = {}
def display_once(e: Exception, task):
if task in already_displayed:
return
display(e, task)
already_displayed[task] = 1
def run(code, task): def run(code, task):
try: try:
code() code()

View file

@ -9,7 +9,7 @@ from torch import einsum
from ldm.util import default from ldm.util import default
from einops import rearrange from einops import rearrange
from modules import shared from modules import shared, errors
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from .sub_quadratic_attention import efficient_dot_product_attention from .sub_quadratic_attention import efficient_dot_product_attention
@ -279,6 +279,21 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
) )
def get_xformers_flash_attention_op(q, k, v):
if not shared.cmd_opts.xformers_flash_attention:
return None
try:
flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
fw, bw = flash_attention_op
if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
return flash_attention_op
except Exception as e:
errors.display_once(e, "enabling flash attention")
return None
def xformers_attention_forward(self, x, context=None, mask=None): def xformers_attention_forward(self, x, context=None, mask=None):
h = self.heads h = self.heads
q_in = self.to_q(x) q_in = self.to_q(x)
@ -291,18 +306,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in del q_in, k_in, v_in
if shared.cmd_opts.xformers_flash_attention: out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
fw, bw = op
if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
# print('xformers_attention_forward', q.shape, k.shape, v.shape)
# Flash Attention is not availabe for the input arguments.
# Fallback to default xFormers' backend.
op = None
else:
op = None
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=op)
out = rearrange(out, 'b n h d -> b n (h d)', h=h) out = rearrange(out, 'b n h d -> b n (h d)', h=h)
return self.to_out(out) return self.to_out(out)
@ -377,17 +381,7 @@ def xformers_attnblock_forward(self, x):
q = q.contiguous() q = q.contiguous()
k = k.contiguous() k = k.contiguous()
v = v.contiguous() v = v.contiguous()
if shared.cmd_opts.xformers_flash_attention: out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
fw, bw = op
if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v)):
# print('xformers_attnblock_forward', q.shape, k.shape, v.shape)
# Flash Attention is not availabe for the input arguments.
# Fallback to default xFormers' backend.
op = None
else:
op = None
out = xformers.ops.memory_efficient_attention(q, k, v, op=op)
out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = rearrange(out, 'b (h w) c -> b c h w', h=h)
out = self.proj_out(out) out = self.proj_out(out)
return x + out return x + out