alternate prompt
This commit is contained in:
parent
34acad1628
commit
a5550f0213
1 changed files with 7 additions and 2 deletions
|
@ -13,13 +13,14 @@ import lark
|
||||||
|
|
||||||
schedule_parser = lark.Lark(r"""
|
schedule_parser = lark.Lark(r"""
|
||||||
!start: (prompt | /[][():]/+)*
|
!start: (prompt | /[][():]/+)*
|
||||||
prompt: (emphasized | scheduled | plain | WHITESPACE)*
|
prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
|
||||||
!emphasized: "(" prompt ")"
|
!emphasized: "(" prompt ")"
|
||||||
| "(" prompt ":" prompt ")"
|
| "(" prompt ":" prompt ")"
|
||||||
| "[" prompt "]"
|
| "[" prompt "]"
|
||||||
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
|
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
|
||||||
|
alternate: "[" prompt ("|" prompt)+ "]"
|
||||||
WHITESPACE: /\s+/
|
WHITESPACE: /\s+/
|
||||||
plain: /([^\\\[\]():]|\\.)+/
|
plain: /([^\\\[\]():|]|\\.)+/
|
||||||
%import common.SIGNED_NUMBER -> NUMBER
|
%import common.SIGNED_NUMBER -> NUMBER
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
@ -59,6 +60,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
||||||
tree.children[-1] *= steps
|
tree.children[-1] *= steps
|
||||||
tree.children[-1] = min(steps, int(tree.children[-1]))
|
tree.children[-1] = min(steps, int(tree.children[-1]))
|
||||||
l.append(tree.children[-1])
|
l.append(tree.children[-1])
|
||||||
|
def alternate(self, tree):
|
||||||
|
l.extend(range(1, steps+1))
|
||||||
CollectSteps().visit(tree)
|
CollectSteps().visit(tree)
|
||||||
return sorted(set(l))
|
return sorted(set(l))
|
||||||
|
|
||||||
|
@ -67,6 +70,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
||||||
def scheduled(self, args):
|
def scheduled(self, args):
|
||||||
before, after, _, when = args
|
before, after, _, when = args
|
||||||
yield before or () if step <= when else after
|
yield before or () if step <= when else after
|
||||||
|
def alternate(self, args):
|
||||||
|
yield next(args[(step - 1)%len(args)])
|
||||||
def start(self, args):
|
def start(self, args):
|
||||||
def flatten(x):
|
def flatten(x):
|
||||||
if type(x) == str:
|
if type(x) == str:
|
||||||
|
|
Loading…
Reference in a new issue