add xformers attention

This commit is contained in:
C43H66N12O12S2 2022-10-07 05:21:49 +03:00 committed by GitHub
parent 2995107fa2
commit f174fb2922
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,7 +1,9 @@
import math import math
import torch import torch
from torch import einsum from torch import einsum
import xformers.ops
import functorch
xformers._is_functorch_available=True
from ldm.util import default from ldm.util import default
from einops import rearrange from einops import rearrange
@ -92,6 +94,41 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
return self.to_out(r2) return self.to_out(r2)
def _maybe_init(self, x):
"""
Initialize the attention operator, if required We expect the head dimension to be exposed here, meaning that x
: B, Head, Length
"""
if self.attention_op is not None:
return
_, M, K = x.shape
try:
self.attention_op = xformers.ops.AttentionOpDispatch(
dtype=x.dtype,
device=x.device,
k=K,
attn_bias_type=type(None),
has_dropout=False,
kv_len=M,
q_len=M,
).op
except NotImplementedError as err:
raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}")
def xformers_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context)
v_in = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
self._maybe_init(q)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)
def cross_attention_attnblock_forward(self, x): def cross_attention_attnblock_forward(self, x):
h_ = x h_ = x
h_ = self.norm(h_) h_ = self.norm(h_)