Use apply_hypernetwork function
This commit is contained in:
parent
574c8e554a
commit
861db783c7
1 changed files with 4 additions and 10 deletions
|
@ -202,16 +202,10 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
||||||
q = self.to_q(x)
|
q = self.to_q(x)
|
||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
|
|
||||||
hypernetwork = shared.loaded_hypernetwork
|
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
||||||
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
|
k = self.to_k(context_k) * self.scale
|
||||||
|
v = self.to_v(context_v)
|
||||||
if hypernetwork_layers is not None:
|
del context, context_k, context_v, x
|
||||||
k = self.to_k(hypernetwork_layers[0](context)) * self.scale
|
|
||||||
v = self.to_v(hypernetwork_layers[1](context))
|
|
||||||
else:
|
|
||||||
k = self.to_k(context) * self.scale
|
|
||||||
v = self.to_v(context)
|
|
||||||
del context, x
|
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||||
r = einsum_op(q, k, v)
|
r = einsum_op(q, k, v)
|
||||||
|
|
Loading…
Reference in a new issue