stable-diffusion-webui/modules/hypernetwork.py
2022-10-11 11:10:17 +03:00

103 lines
3.1 KiB
Python

import glob
import os
import sys
import traceback
import torch
from ldm.util import default
from modules import devices, shared
import torch
from torch import einsum
from einops import rearrange, repeat
class HypernetworkModule(torch.nn.Module):
    """Residual block of two stacked linear layers loaded from a state dict.

    The input is widened from ``dim`` to ``dim * 2`` features, projected back
    to ``dim``, and the result is added to the original input. No activation
    is applied between the two linear layers.
    """

    def __init__(self, dim, state_dict):
        """Create both linear layers and load their weights.

        Args:
            dim: Feature size of the module's input and output.
            state_dict: Parameters for ``linear1`` and ``linear2``. Loaded
                with ``strict=True``, so missing or unexpected keys raise.
        """
        super().__init__()

        hidden_dim = dim * 2
        self.linear1 = torch.nn.Linear(dim, hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, dim)

        self.load_state_dict(state_dict, strict=True)
        # Move the freshly loaded weights to the device chosen by the project.
        self.to(devices.device)

    def forward(self, x):
        """Return ``x`` plus the output of the two linear layers applied to it."""
        widened = self.linear1(x)
        correction = self.linear2(widened)
        return x + correction
class Hypernetwork:
    """A hypernetwork read from disk: one pair of modules per feature size."""

    # Class-level defaults; every instance overwrites both in __init__.
    filename = None
    name = None

    def __init__(self, filename):
        """Load the checkpoint at ``filename`` and build its modules.

        Args:
            filename: Path of the checkpoint file. Its base name without the
                extension becomes ``self.name``.
        """
        self.filename = filename
        stem, _extension = os.path.splitext(os.path.basename(filename))
        self.name = stem

        # NOTE(review): torch.load unpickles the file, so only checkpoints
        # from a trusted source should be loaded here.
        checkpoint = torch.load(filename, map_location='cpu')

        # Each checkpoint entry maps a size to a sequence whose first two
        # items are state dicts; apply_hypernetwork uses the first module for
        # keys and the second for values.
        self.layers = {
            size: (
                HypernetworkModule(size, state_dicts[0]),
                HypernetworkModule(size, state_dicts[1]),
            )
            for size, state_dicts in checkpoint.items()
        }
def list_hypernetworks(path):
    """Map hypernetwork names to the ``.pt`` files found under ``path``.

    The directory is searched recursively. A file's name is its base name
    without the extension; when several files share a name, the one yielded
    last by the search wins.

    Args:
        path: Directory to search.

    Returns:
        Dict of ``{name: file path}``; empty when nothing matches.
    """
    # Escape the directory part so glob metacharacters in it ("[", "]", "*",
    # "?") are matched literally. Without this, a folder such as
    # "models [v1]" is read as a pattern and its files are never found.
    pattern = os.path.join(glob.escape(path), '**/*.pt')

    res = {}
    for filename in glob.iglob(pattern, recursive=True):
        name = os.path.splitext(os.path.basename(filename))[0]
        res[name] = filename
    return res
def load_hypernetwork(filename):
    """Make the named hypernetwork the active one, or unload the current one.

    ``filename`` is looked up in ``shared.hypernetworks``. When it resolves to
    a path, that file is loaded into ``shared.loaded_hypernetwork``; a load
    failure is reported on stderr and the previous value is left untouched.
    When the name is not found, ``shared.loaded_hypernetwork`` is reset to
    ``None``.

    Args:
        filename: Key into ``shared.hypernetworks``.
    """
    path = shared.hypernetworks.get(filename, None)

    if path is None:
        # Unknown name: drop whatever is currently loaded.
        if shared.loaded_hypernetwork is not None:
            print("Unloading hypernetwork")
        shared.loaded_hypernetwork = None
        return

    print(f"Loading hypernetwork {filename}")
    try:
        shared.loaded_hypernetwork = Hypernetwork(path)
    except Exception:
        # Best effort: report the failure instead of propagating it.
        print(f"Error loading hypernetwork {path}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
def apply_hypernetwork(hypernetwork, context):
    """Return the key and value contexts produced by ``hypernetwork``.

    The module pair is selected by ``context.shape[2]``. When no hypernetwork
    is given, or it has no pair for that size, ``context`` is returned
    unchanged for both.

    Args:
        hypernetwork: Object with a ``layers`` mapping, or ``None``.
        context: Input whose ``shape[2]`` selects the module pair.

    Returns:
        Tuple ``(context_k, context_v)``.
    """
    layer_table = {} if hypernetwork is None else hypernetwork.layers

    # shape[2] is read even when the table is empty, as the lookup key.
    module_pair = layer_table.get(context.shape[2], None)
    if module_pair is None:
        return context, context

    key_module = module_pair[0]
    value_module = module_pair[1]
    return key_module(context), value_module(context)
def attention_CrossAttention_forward(self, x, context=None, mask=None):
    """Attention forward pass with the context routed through the loaded hypernetwork.

    Takes ``self`` like a ``forward`` method and reads ``self.heads``,
    ``self.scale``, ``self.to_q``, ``self.to_k``, ``self.to_v`` and
    ``self.to_out``.

    Args:
        x: Input used to compute the queries.
        context: Source of keys and values, resolved via ``default(context, x)``
            (presumably falls back to ``x`` when ``None`` — confirm in ldm.util).
        mask: Optional mask; positions where ``~mask`` is set are filled with
            the most negative finite value of the score dtype before softmax
            (assumes a boolean mask — TODO confirm).

    Returns:
        The result of ``self.to_out`` applied to the attended values.
    """
    heads = self.heads

    query = self.to_q(x)
    context = default(context, x)

    # Keys and values each get their own hypernetwork-transformed context.
    key_context, value_context = apply_hypernetwork(shared.loaded_hypernetwork, context)
    key = self.to_k(key_context)
    value = self.to_v(value_context)

    # Fold the head dimension into the batch dimension.
    query, key, value = (
        rearrange(tensor, 'b n (h d) -> (b h) n d', h=heads)
        for tensor in (query, key, value)
    )

    scores = einsum('b i d, b j d -> b i j', query, key) * self.scale

    if mask is not None:
        flat_mask = rearrange(mask, 'b ... -> b (...)')
        fill_value = -torch.finfo(scores.dtype).max
        # Repeat the mask once per head so it lines up with the folded batch.
        head_mask = repeat(flat_mask, 'b j -> (b h) () j', h=heads)
        scores.masked_fill_(~head_mask, fill_value)

    weights = scores.softmax(dim=-1)

    attended = einsum('b i j, b j d -> b i d', weights, value)
    # Unfold the heads back into the feature dimension.
    merged = rearrange(attended, '(b h) n d -> b n (h d)', h=heads)
    return self.to_out(merged)