Merge pull request #335 from C43H66N12O12S2/attention-update
Update cross attention to the newest version
commit a26f157a5e
1 changed file with 4 additions and 3 deletions
@@ -67,8 +67,9 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
     mem_free_total = mem_free_cuda + mem_free_torch

     gb = 1024 ** 3
-    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
-    mem_required = tensor_size * 2.5
+    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
+    modifier = 3 if q.element_size() == 2 else 2.5
+    mem_required = tensor_size * modifier
     steps = 1

     if mem_required > mem_free_total:
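The first hunk sizes the attention-score buffer from the tensor's actual element width instead of a hard-coded 4 bytes, and gives half precision a larger safety modifier. A minimal sketch of that estimate, with illustrative shapes that are assumptions and not part of the commit:

    import torch

    # Illustrative q/k shapes (batch*heads, seq_len, head_dim); only the
    # shapes and dtype matter for the estimate, so empty tensors suffice.
    q = torch.empty(16, 4096, 40, dtype=torch.float16)
    k = torch.empty(16, 4096, 40, dtype=torch.float16)

    # The score matrix is (q.shape[0], q.shape[1], k.shape[1]);
    # element_size() is 2 bytes for fp16 and 4 for fp32, replacing the
    # previous hard-coded 4.
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()

    # fp16 gets a larger headroom multiplier (3 vs 2.5) per the diff.
    modifier = 3 if q.element_size() == 2 else 2.5
    mem_required = tensor_size * modifier

    gb = 1024 ** 3
    print(f"estimated peak: {mem_required / gb:.2f} GiB")  # 1.50 GiB here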
@@ -86,7 +87,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
         end = i + slice_size
         s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale

-        s2 = s1.softmax(dim=-1)
+        s2 = s1.softmax(dim=-1, dtype=q.dtype)
         del s1

         r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
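The second hunk pins the softmax output to q's dtype: Tensor.softmax's dtype argument casts the input before the operation, so s2 comes back in q's precision even if s1 was produced at a wider dtype. A minimal sketch with assumed shapes and dtypes (the fp32 s1 stands in for a score tensor upcast by mixed precision; none of these values come from the commit):

    import torch

    q = torch.empty(2, 64, 40, dtype=torch.float16)
    s1 = torch.randn(2, 64, 64)  # fp32 scores, e.g. after an upcast

    # Without dtype=, the softmax result stays fp32; with dtype=q.dtype it
    # is returned in q's precision, halving the memory held by s2 when q
    # is fp16 and keeping the following einsum's dtypes consistent.
    s2_default = s1.softmax(dim=-1)
    s2_matched = s1.softmax(dim=-1, dtype=q.dtype)

    print(s2_default.dtype)  # torch.float32
    print(s2_matched.dtype)  # torch.float16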