xy_grid: Find hypernetwork by closest name
This commit is contained in:
parent
4aeacaefbf
commit
2d006ce16c
2 changed files with 16 additions and 1 deletions
|
@ -120,6 +120,17 @@ def load_hypernetwork(filename):
|
||||||
shared.loaded_hypernetwork = None
|
shared.loaded_hypernetwork = None
|
||||||
|
|
||||||
|
|
||||||
|
def find_closest_hypernetwork_name(search: str):
|
||||||
|
if not search:
|
||||||
|
return None
|
||||||
|
search = search.lower()
|
||||||
|
applicable = [name for name in shared.hypernetworks if search in name.lower()]
|
||||||
|
if not applicable:
|
||||||
|
return None
|
||||||
|
applicable = sorted(applicable, key=lambda name: len(name))
|
||||||
|
return applicable[0]
|
||||||
|
|
||||||
|
|
||||||
def apply_hypernetwork(hypernetwork, context, layer=None):
|
def apply_hypernetwork(hypernetwork, context, layer=None):
|
||||||
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
|
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
|
||||||
|
|
||||||
|
|
|
@ -84,7 +84,11 @@ def apply_checkpoint(p, x, xs):
|
||||||
|
|
||||||
|
|
||||||
def apply_hypernetwork(p, x, xs):
|
def apply_hypernetwork(p, x, xs):
|
||||||
hypernetwork.load_hypernetwork(x)
|
if x.lower() in ["", "none"]:
|
||||||
|
name = None
|
||||||
|
else:
|
||||||
|
name = hypernetwork.find_closest_hypernetwork_name(x)
|
||||||
|
hypernetwork.load_hypernetwork(name)
|
||||||
|
|
||||||
|
|
||||||
def apply_clip_skip(p, x, xs):
|
def apply_clip_skip(p, x, xs):
|
||||||
|
|
Loading…
Reference in a new issue