remove functools.cache as some people are having issues with it

This commit is contained in:
AUTOMATIC 2022-10-04 18:02:01 +03:00
parent e1b128d8e4
commit 1eb588cbf1

View file

@ -29,6 +29,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
%import common.SIGNED_NUMBER -> NUMBER %import common.SIGNED_NUMBER -> NUMBER
""" """
parser = Lark(grammar, parser='lalr') parser = Lark(grammar, parser='lalr')
def collect_steps(steps, tree): def collect_steps(steps, tree):
l = [steps] l = [steps]
class CollectSteps(Visitor): class CollectSteps(Visitor):
@ -40,6 +41,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
l.append(tree.children[-1]) l.append(tree.children[-1])
CollectSteps().visit(tree) CollectSteps().visit(tree)
return sorted(set(l)) return sorted(set(l))
def at_step(step, tree): def at_step(step, tree):
class AtStep(Transformer): class AtStep(Transformer):
def scheduled(self, args): def scheduled(self, args):
@ -62,11 +64,13 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
for child in children: for child in children:
yield from child yield from child
return AtStep().transform(tree) return AtStep().transform(tree)
@functools.cache
def get_schedule(prompt): def get_schedule(prompt):
tree = parser.parse(prompt) tree = parser.parse(prompt)
return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)] return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
return [get_schedule(prompt) for prompt in prompts]
promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
return [promptdict[prompt] for prompt in prompts]
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"]) ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])