Allow nested structures inside schedules

This commit is contained in:
dan 2022-10-03 19:25:36 +08:00 committed by AUTOMATIC1111
parent 6c6ae28bf5
commit 2f1b61d979
3 changed files with 55 additions and 66 deletions

View file

@ -1,20 +1,11 @@
import re import re
from collections import namedtuple from collections import namedtuple
import torch import torch
from lark import Lark, Transformer, Visitor
import functools
import modules.shared as shared import modules.shared as shared
re_prompt = re.compile(r'''
(.*?)
\[
([^]:]+):
(?:([^]:]*):)?
([0-9]*\.?[0-9]+)
]
|
(.+)
''', re.X)
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100): # will be represented with prompt_schedule like this (assuming steps=100):
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy'] # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
@ -25,61 +16,57 @@ re_prompt = re.compile(r'''
def get_learned_conditioning_prompt_schedules(prompts, steps): def get_learned_conditioning_prompt_schedules(prompts, steps):
res = [] grammar = r"""
cache = {} start: prompt
prompt: (emphasized | scheduled | weighted | plain)*
for prompt in prompts: !emphasized: "(" prompt ")"
prompt_schedule: list[list[str | int]] = [[steps, ""]] | "(" prompt ":" prompt ")"
| "[" prompt "]"
cached = cache.get(prompt, None) scheduled: "[" (prompt ":")? prompt ":" NUMBER "]"
if cached is not None: !weighted: "{" weighted_item ("|" weighted_item)* "}"
res.append(cached) !weighted_item: prompt (":" prompt)?
continue plain: /([^\\\[\](){}:|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
for m in re_prompt.finditer(prompt): """
plaintext = m.group(1) if m.group(5) is None else m.group(5) parser = Lark(grammar, parser='lalr')
concept_from = m.group(2) def collect_steps(steps, tree):
concept_to = m.group(3) l = [steps]
if concept_to is None: class CollectSteps(Visitor):
concept_to = concept_from def scheduled(self, tree):
concept_from = "" tree.children[-1] = float(tree.children[-1])
swap_position = float(m.group(4)) if m.group(4) is not None else None if tree.children[-1] < 1:
tree.children[-1] *= steps
if swap_position is not None: tree.children[-1] = min(steps, int(tree.children[-1]))
if swap_position < 1: l.append(tree.children[-1])
swap_position = swap_position * steps CollectSteps().visit(tree)
swap_position = int(min(swap_position, steps)) return sorted(set(l))
def at_step(step, tree):
swap_index = None class AtStep(Transformer):
found_exact_index = False def scheduled(self, args):
for i in range(len(prompt_schedule)): if len(args) == 2:
end_step = prompt_schedule[i][0] before, after, when = (), *args
prompt_schedule[i][1] += plaintext else:
before, after, when = args
if swap_position is not None and swap_index is None: yield before if step <= when else after
if swap_position == end_step: def start(self, args):
swap_index = i def flatten(x):
found_exact_index = True if type(x) == str:
yield x
if swap_position < end_step: else:
swap_index = i for gen in x:
yield from flatten(gen)
if swap_index is not None: return ''.join(flatten(args[0]))
if not found_exact_index: def plain(self, args):
prompt_schedule.insert(swap_index, [swap_position, prompt_schedule[swap_index][1]]) yield args[0].value
def __default__(self, data, children, meta):
for i in range(len(prompt_schedule)): for child in children:
end_step = prompt_schedule[i][0] yield from child
must_replace = swap_position < end_step return AtStep().transform(tree)
@functools.cache
prompt_schedule[i][1] += concept_to if must_replace else concept_from def get_schedule(prompt):
tree = parser.parse(prompt)
res.append(prompt_schedule) return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
cache[prompt] = prompt_schedule return [get_schedule(prompt) for prompt in prompts]
#for t in prompt_schedule:
# print(t)
return res
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"]) ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])

View file

@ -22,3 +22,4 @@ clean-fid
resize-right resize-right
torchdiffeq torchdiffeq
kornia kornia
lark

View file

@ -21,3 +21,4 @@ clean-fid==0.1.29
resize-right==0.0.2 resize-right==0.0.2
torchdiffeq==0.2.3 torchdiffeq==0.2.3
kornia==0.6.7 kornia==0.6.7
lark==1.1.2