diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 9eb6cc20..c058ac6e 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -67,8 +67,9 @@ def split_cross_attention_forward(self, x, context=None, mask=None): mem_free_total = mem_free_cuda + mem_free_torch gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4 - mem_required = tensor_size * 2.5 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier steps = 1 if mem_required > mem_free_total: @@ -86,7 +87,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): end = i + slice_size s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale - s2 = s1.softmax(dim=-1) + s2 = s1.softmax(dim=-1, dtype=q.dtype) del s1 r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)