add fallback for xformers_attnblock_forward
This commit is contained in:
parent
a5550f0213
commit
f9c5da1592
1 changed file with 4 additions and 1 deletion
|
@@ -211,6 +211,7 @@ def cross_attention_attnblock_forward(self, x):
|
||||||
return h3
|
return h3
|
||||||
|
|
||||||
def xformers_attnblock_forward(self, x):
|
def xformers_attnblock_forward(self, x):
|
||||||
|
try:
|
||||||
h_ = x
|
h_ = x
|
||||||
h_ = self.norm(h_)
|
h_ = self.norm(h_)
|
||||||
q1 = self.q(h_).contiguous()
|
q1 = self.q(h_).contiguous()
|
||||||
|
@@ -218,4 +219,6 @@ def xformers_attnblock_forward(self, x):
|
||||||
v = self.v(h_).contiguous()
|
v = self.v(h_).contiguous()
|
||||||
out = xformers.ops.memory_efficient_attention(q1, k1, v)
|
out = xformers.ops.memory_efficient_attention(q1, k1, v)
|
||||||
out = self.proj_out(out)
|
out = self.proj_out(out)
|
||||||
return x+out
|
return x + out
|
||||||
|
except NotImplementedError:
|
||||||
|
return cross_attention_attnblock_forward(self, x)
|
||||||
|
|
Loading…
Reference in a new issue