add fallback for xformers_attnblock_forward
This commit is contained in:
parent
a5550f0213
commit
f9c5da1592
1 changed files with 4 additions and 1 deletions
|
@ -211,6 +211,7 @@ def cross_attention_attnblock_forward(self, x):
|
|||
return h3
|
||||
|
||||
def xformers_attnblock_forward(self, x):
|
||||
try:
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q1 = self.q(h_).contiguous()
|
||||
|
@ -218,4 +219,6 @@ def xformers_attnblock_forward(self, x):
|
|||
v = self.v(h_).contiguous()
|
||||
out = xformers.ops.memory_efficient_attention(q1, k1, v)
|
||||
out = self.proj_out(out)
|
||||
return x+out
|
||||
return x + out
|
||||
except NotImplementedError:
|
||||
return cross_attention_attnblock_forward(self, x)
|
||||
|
|
Loading…
Reference in a new issue