This commit is contained in:
AUTOMATIC 2022-09-05 01:41:20 +03:00
parent 407fc1fe0c
commit 5bb126bd89
2 changed files with 44 additions and 1 deletions

View file

@ -3,8 +3,43 @@ import sys
import traceback import traceback
import torch import torch
import numpy as np import numpy as np
from torch import einsum
from modules.shared import opts, device from modules.shared import opts, device, cmd_opts
from ldm.util import default
from einops import rearrange
import ldm.modules.attention
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
for i in range(0, q.shape[0], 2):
end = i + 2
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
s2 = s1.softmax(dim=-1)
del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
class StableDiffusionModelHijack: class StableDiffusionModelHijack:
@ -67,6 +102,9 @@ class StableDiffusionModelHijack:
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
if cmd_opts.opt_split_attention:
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack): def __init__(self, wrapped, hijack):
@ -205,4 +243,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
return inputs_embeds return inputs_embeds
model_hijack = StableDiffusionModelHijack() model_hijack = StableDiffusionModelHijack()

View file

@ -29,6 +29,7 @@ parser.add_argument("--unload-gfpgan", action='store_true', help="unload GFPGAN
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN')) parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN'))
parser.add_argument("--opt-split-attention", type=str, help="enable optimization that reduced vram usage by a lot for about 10% decrease in performance", default=os.path.join(script_path, 'ESRGAN'))
cmd_opts = parser.parse_args() cmd_opts = parser.parse_args()
cpu = torch.device("cpu") cpu = torch.device("cpu")