sdp_attnblock_forward hijack
commit 8d7fa2f67c (parent 0981dea948)
2 changed files with 26 additions and 0 deletions
@@ -47,10 +47,12 @@ def apply_optimizations():
     elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
         print("Applying scaled dot product cross attention optimization (without memory efficient attention).")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
         optimization_method = 'sdp-no-mem'
     elif cmd_opts.opt_sdp_attention and can_use_sdp:
         print("Applying scaled dot product cross attention optimization.")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
         optimization_method = 'sdp'
     elif cmd_opts.opt_sub_quad_attention:
         print("Applying sub-quadratic cross attention optimization.")
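The hunk above extends apply_optimizations() so that, when an SDP optimization is selected, ldm.modules.diffusionmodules.model.AttnBlock.forward is patched alongside ldm.modules.attention.CrossAttention.forward. As a minimal, self-contained sketch of the forward-hijack (monkey-patching) pattern this relies on, the following illustrative code uses hypothetical names (TinyAttnBlock, patched_forward) that are not part of the webui:

# Minimal sketch of the forward-hijack pattern (illustrative only;
# TinyAttnBlock and patched_forward are hypothetical, not webui code).
import torch
import torch.nn as nn


class TinyAttnBlock(nn.Module):
    def forward(self, x):
        # stand-in for the original, unoptimized attention implementation
        return x


def patched_forward(self, x):
    # stand-in for an optimized implementation such as sdp_attnblock_forward
    return x


# Rebinding the class attribute makes every existing and future instance
# use the replacement, which is how apply_optimizations() swaps attention
# kernels without touching the model definitions in ldm.
TinyAttnBlock.forward = patched_forward

block = TinyAttnBlock()
print(block(torch.zeros(1, 4, 8, 8)).shape)  # torch.Size([1, 4, 8, 8])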
@@ -473,6 +473,30 @@ def xformers_attnblock_forward(self, x):
     except NotImplementedError:
         return cross_attention_attnblock_forward(self, x)
 
+def sdp_attnblock_forward(self, x):
+    h_ = x
+    h_ = self.norm(h_)
+    q = self.q(h_)
+    k = self.k(h_)
+    v = self.v(h_)
+    b, c, h, w = q.shape
+    q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+    dtype = q.dtype
+    if shared.opts.upcast_attn:
+        q, k = q.float(), k.float()
+    q = q.contiguous()
+    k = k.contiguous()
+    v = v.contiguous()
+    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
+    out = out.to(dtype)
+    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
+    out = self.proj_out(out)
+    return x + out
+
+def sdp_no_mem_attnblock_forward(self, x):
+    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
+        return sdp_attnblock_forward(self, x)
+
 def sub_quad_attnblock_forward(self, x):
     h_ = x
     h_ = self.norm(h_)
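As a rough standalone illustration of the data flow inside the new sdp_attnblock_forward, the sketch below flattens a feature map to a (batch, tokens, channels) layout, runs PyTorch 2.0's fused scaled_dot_product_attention, and restores the spatial shape. The tensor sizes are made up, and the norm, q/k/v projections, proj_out, and residual add from the function above are omitted:

# Rough sketch of the shape handling in sdp_attnblock_forward
# (sizes are illustrative; norm/q/k/v/proj_out layers are omitted).
import torch
from einops import rearrange

b, c, h, w = 1, 512, 32, 32
q = torch.randn(b, c, h, w)
k = torch.randn(b, c, h, w)
v = torch.randn(b, c, h, w)

# Flatten the spatial grid into a token axis: (b, c, h, w) -> (b, h*w, c)
q, k, v = (rearrange(t, 'b c h w -> b (h w) c').contiguous() for t in (q, k, v))

# PyTorch 2.0 fused attention; the backend (flash / mem-efficient / math)
# is chosen automatically unless restricted as in sdp_no_mem_attnblock_forward.
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

# Restore the feature-map layout expected by the surrounding convolutions.
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
print(out.shape)  # torch.Size([1, 512, 32, 32])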
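The sdp_no_mem_attnblock_forward variant wraps the same call in torch.backends.cuda.sdp_kernel so the memory-efficient backend is excluded and PyTorch can only pick flash attention or the math fallback. A minimal sketch of that backend restriction, with made-up tensor sizes:

# Sketch of the backend restriction used by sdp_no_mem_attnblock_forward
# (sizes are illustrative; the flags only affect CUDA backend selection).
import torch

q = torch.randn(1, 1024, 512)
k = torch.randn(1, 1024, 512)
v = torch.randn(1, 1024, 512)

# Disallow the memory-efficient kernel; flash attention and the math
# fallback remain available inside this context.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

print(out.shape)  # torch.Size([1, 1024, 512])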