Merge pull request #635 from C43H66N12O12S2/attention
Move scale multiplication to the front
This commit is contained in:
commit
17b60490fa
1 changed files with 2 additions and 2 deletions
|
@ -50,7 +50,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||||
|
|
||||||
q_in = self.to_q(x)
|
q_in = self.to_q(x)
|
||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
k_in = self.to_k(context)
|
k_in = self.to_k(context) * self.scale
|
||||||
v_in = self.to_v(context)
|
v_in = self.to_v(context)
|
||||||
del context, x
|
del context, x
|
||||||
|
|
||||||
|
@ -85,7 +85,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||||
for i in range(0, q.shape[1], slice_size):
|
for i in range(0, q.shape[1], slice_size):
|
||||||
end = i + slice_size
|
end = i + slice_size
|
||||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
|
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
||||||
|
|
||||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||||
del s1
|
del s1
|
||||||
|
|
Loading…
Reference in a new issue