experimental optimization

This commit is contained in:
AUTOMATIC 2023-01-05 21:00:52 +03:00
parent f8d0cf6a6e
commit 847f869c67

View file

@ -544,6 +544,29 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
infotexts = [] infotexts = []
output_images = [] output_images = []
cached_uc = [None, None]
cached_c = [None, None]
def get_conds_with_caching(function, required_prompts, steps, cache):
"""
Returns the result of calling function(shared.sd_model, required_prompts, steps)
using a cache to store the result if the same arguments have been used before.
cache is an array containing two elements. The first element is a tuple
representing the previously used arguments, or None if no arguments
have been used before. The second element is where the previously
computed result is stored.
"""
if cache[0] is not None and (required_prompts, steps) == cache[0]:
return cache[1]
with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps)
cache[0] = (required_prompts, steps)
return cache[1]
with torch.no_grad(), p.sd_model.ema_scope(): with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast(): with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds) p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@ -571,9 +594,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None: if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds) p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
with devices.autocast(): uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps) c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
if len(model_hijack.comments) > 0: if len(model_hijack.comments) > 0:
for comment in model_hijack.comments: for comment in model_hijack.comments: