make it possible to use hypernetworks without opt split attention

AUTOMATIC 2022-10-07 16:39:51 +03:00
parent 97bc0b9504
commit f7c787eb7c
2 changed files with 38 additions and 10 deletions

modules/hypernetwork.py

@@ -4,7 +4,12 @@ import sys
 import traceback
 
 import torch
-from modules import devices
+from ldm.util import default
+from modules import devices, shared
+
+import torch
+from torch import einsum
+from einops import rearrange, repeat
 
 
 class HypernetworkModule(torch.nn.Module):
@@ -48,15 +53,36 @@ def load_hypernetworks(path):
     return res
 
 
-def apply(self, x, context=None, mask=None, original=None):
+def attention_CrossAttention_forward(self, x, context=None, mask=None):
+    h = self.heads
 
-    if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork:
-        if context.shape[1] == 77 and CrossAttention.noise_cond:
-            context = context + (torch.randn_like(context) * 0.1)
-        h_k, h_v = CrossAttention.hypernetwork[context.shape[2]]
-        k = self.to_k(h_k(context))
-        v = self.to_v(h_v(context))
+    q = self.to_q(x)
+    context = default(context, x)
+
+    hypernetwork = shared.selected_hypernetwork()
+    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+    if hypernetwork_layers is not None:
+        k = self.to_k(hypernetwork_layers[0](context))
+        v = self.to_v(hypernetwork_layers[1](context))
     else:
         k = self.to_k(context)
         v = self.to_v(context)
+
+    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+    if mask is not None:
+        mask = rearrange(mask, 'b ... -> b (...)')
+        max_neg_value = -torch.finfo(sim.dtype).max
+        mask = repeat(mask, 'b j -> (b h) () j', h=h)
+        sim.masked_fill_(~mask, max_neg_value)
+
+    # attention, what we cannot get enough of
+    attn = sim.softmax(dim=-1)
+
+    out = einsum('b i j, b j d -> b i d', attn, v)
+    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+    return self.to_out(out)
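
The new forward expects the selected hypernetwork to expose a layers dict keyed by the context width (context.shape[2]), each entry holding a pair of modules applied to the context before the key and value projections. The following standalone sketch illustrates only that lookup; ToyHypernetwork and transform_context are hypothetical names, and plain nn.Linear modules stand in for the repo's learned HypernetworkModule layers.

# Minimal sketch (not the repo's code) of the hypernetwork layer lookup used above.
import torch
import torch.nn as nn


class ToyHypernetwork:
    def __init__(self, dims=(768,)):
        # one (key_module, value_module) pair per supported context width
        self.layers = {dim: (nn.Linear(dim, dim), nn.Linear(dim, dim)) for dim in dims}


def transform_context(hypernetwork, context):
    # mirrors the branch in attention_CrossAttention_forward: apply the pair if the
    # selected hypernetwork has layers for this context width, otherwise pass through
    layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
    if layers is not None:
        return layers[0](context), layers[1](context)
    return context, context


context = torch.randn(2, 77, 768)              # (batch, tokens, width)
k_in, v_in = transform_context(ToyHypernetwork(), context)
print(k_in.shape, v_in.shape)                  # torch.Size([2, 77, 768]) twice

With no hypernetwork selected, or no entry for the current width, k and v fall back to the untouched context, so the same forward serves both cases.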

modules/sd_hijack.py

@@ -8,7 +8,7 @@ from torch import einsum
 from torch.nn.functional import silu
 
 import modules.textual_inversion.textual_inversion
-from modules import prompt_parser, devices, sd_hijack_optimizations, shared
+from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork
 from modules.shared import opts, device, cmd_opts
 
 import ldm.modules.attention
@@ -20,6 +20,8 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
 
 
 def apply_optimizations():
+    undo_optimizations()
+
     ldm.modules.diffusionmodules.model.nonlinearity = silu
 
     if cmd_opts.opt_split_attention_v1:
@@ -30,7 +32,7 @@ def apply_optimizations():
 
 
 def undo_optimizations():
-    ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward
+    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
     ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
     ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
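
Taken together, the two files change the patching order: undo_optimizations() now installs the hypernetwork-aware forward from modules/hypernetwork.py as the baseline CrossAttention.forward, and apply_optimizations() calls it first before optionally swapping in a split-attention variant. A rough sketch of that flow, using simplified stand-in names rather than the repository's actual classes and flags:

# Rough sketch (hypothetical names) of the monkey-patching order after this commit.
class CrossAttention:                       # stand-in for ldm.modules.attention.CrossAttention
    def forward(self, x):
        return "original forward"


def hypernetwork_forward(self, x):          # stand-in for hypernetwork.attention_CrossAttention_forward
    return "vanilla attention + hypernetwork"


def split_attention_forward(self, x):       # stand-in for the split-attention optimization
    return "split attention"


def undo_optimizations():
    # baseline is now the hypernetwork-aware forward, not the stock one
    CrossAttention.forward = hypernetwork_forward


def apply_optimizations(opt_split_attention=False):
    undo_optimizations()                    # reset to baseline first (the new call in this commit)
    if opt_split_attention:
        CrossAttention.forward = split_attention_forward


apply_optimizations(opt_split_attention=False)
print(CrossAttention().forward(None))       # hypernetworks apply even without the optimization

Before this commit, only the optimized forwards consulted the hypernetwork, so running without the split-attention optimization silently ignored it; the undo_optimizations() call at the top of apply_optimizations() guarantees the baseline forward is always hypernetwork-aware.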