Add UI setting for upcasting attention to float32

Adds "Upcast cross attention layer to float32" option in Stable Diffusion settings. This allows for generating images using SD 2.1 models without --no-half or xFormers.

In order to make upcasting cross attention layer optimizations possible it is necessary to indent several sections of code in sd_hijack_optimizations.py so that a context manager can be used to disable autocast. Also, even though Stable Diffusion (and Diffusers) only upcast q and k, unfortunately my findings were that most of the cross attention layer optimizations could not function unless v is upcast also.
This commit is contained in:
brkirch 2023-01-25 00:23:10 -05:00
parent 84d9ce30cb
commit e3b53fd295
5 changed files with 108 additions and 64 deletions

View file

@ -108,6 +108,10 @@ def autocast(disable=False):
return torch.autocast("cuda")
def without_autocast(disable=False):
return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
class NansException(Exception):
pass
@ -125,7 +129,7 @@ def test_for_nans(x, where):
message = "A tensor with all NaNs was produced in Unet."
if not shared.cmd_opts.no_half:
message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this."
message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."
elif where == "vae":
message = "A tensor with all NaNs was produced in VAE."

View file

@ -611,7 +611,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
with devices.autocast(disable=devices.unet_needs_upcast):
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]

View file

@ -9,7 +9,7 @@ from torch import einsum
from ldm.util import default
from einops import rearrange
from modules import shared, errors
from modules import shared, errors, devices
from modules.hypernetworks import hypernetwork
from .sub_quadratic_attention import efficient_dot_product_attention
@ -52,7 +52,12 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
dtype = q.dtype
if shared.opts.upcast_attn:
q, k, v = q.float(), k.float(), v.float()
with devices.without_autocast(disable=not shared.opts.upcast_attn):
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], 2):
end = i + 2
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
@ -65,6 +70,8 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
del s2
del q, k, v
r1 = r1.to(dtype)
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
@ -82,7 +89,12 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
k_in *= self.scale
dtype = q_in.dtype
if shared.opts.upcast_attn:
q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
with devices.without_autocast(disable=not shared.opts.upcast_attn):
k_in = k_in * self.scale
del context, x
@ -122,6 +134,8 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
del q, k, v
r1 = r1.to(dtype)
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
@ -204,12 +218,20 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
context = default(context, x)
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k = self.to_k(context_k) * self.scale
k = self.to_k(context_k)
v = self.to_v(context_v)
del context, context_k, context_v, x
dtype = q.dtype
if shared.opts.upcast_attn:
q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()
with devices.without_autocast(disable=not shared.opts.upcast_attn):
k = k * self.scale
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r = einsum_op(q, k, v)
r = r.to(dtype)
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
# -- End of code from https://github.com/invoke-ai/InvokeAI --
@ -234,8 +256,14 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
x = x.to(dtype)
x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
out_proj, dropout = self.to_out
@ -268,6 +296,7 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
query_chunk_size = q_tokens
kv_chunk_size = k_tokens
with devices.without_autocast(disable=q.dtype == v.dtype):
return efficient_dot_product_attention(
q,
k,
@ -306,8 +335,14 @@ def xformers_attention_forward(self, x, context=None, mask=None):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
out = out.to(dtype)
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
return self.to_out(out)
@ -378,10 +413,14 @@ def xformers_attnblock_forward(self, x):
v = self.v(h_)
b, c, h, w = q.shape
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
out = out.to(dtype)
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
out = self.proj_out(out)
return x + out

View file

@ -410,6 +410,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {

View file

@ -67,7 +67,7 @@ def _summarize_chunk(
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
max_score = max_score.detach()
exp_weights = torch.exp(attn_weights - max_score)
exp_values = torch.bmm(exp_weights, value)
exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype)
max_score = max_score.squeeze(-1)
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
@ -129,7 +129,7 @@ def _get_attention_scores_no_kv_chunking(
)
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
hidden_states_slice = torch.bmm(attn_probs, value)
hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)
return hidden_states_slice