Use apply_hypernetwork function

This commit is contained in:
brkirch 2022-10-11 05:13:17 -04:00 committed by AUTOMATIC1111
parent 574c8e554a
commit 861db783c7

View file

@ -202,16 +202,10 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
q = self.to_q(x)
context = default(context, x)
hypernetwork = shared.loaded_hypernetwork
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
if hypernetwork_layers is not None:
k = self.to_k(hypernetwork_layers[0](context)) * self.scale
v = self.to_v(hypernetwork_layers[1](context))
else:
k = self.to_k(context) * self.scale
v = self.to_v(context)
del context, x
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
k = self.to_k(context_k) * self.scale
v = self.to_v(context_v)
del context, context_k, context_v, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r = einsum_op(q, k, v)