add postprocess call for scripts
This commit is contained in:
parent
35c45df28b
commit
9bb6b6509a
2 changed files with 30 additions and 6 deletions
|
@ -478,7 +478,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||||
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||||
|
|
||||||
if p.scripts is not None:
|
if p.scripts is not None:
|
||||||
p.scripts.run_alwayson_scripts(p)
|
p.scripts.process(p)
|
||||||
|
|
||||||
infotexts = []
|
infotexts = []
|
||||||
output_images = []
|
output_images = []
|
||||||
|
@ -501,7 +501,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||||
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
|
|
||||||
if (len(prompts) == 0):
|
if len(prompts) == 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
|
@ -590,7 +590,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||||
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
|
|
||||||
|
res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||||
|
|
||||||
|
if p.scripts is not None:
|
||||||
|
p.scripts.postprocess(p, res)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
|
|
|
@ -64,7 +64,16 @@ class Script:
|
||||||
def process(self, p, *args):
|
def process(self, p, *args):
|
||||||
"""
|
"""
|
||||||
This function is called before processing begins for AlwaysVisible scripts.
|
This function is called before processing begins for AlwaysVisible scripts.
|
||||||
scripts. You can modify the processing object (p) here, inject hooks, etc.
|
You can modify the processing object (p) here, inject hooks, etc.
|
||||||
|
args contains all values returned by components from ui()
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
def postprocess(self, p, processed, *args):
|
||||||
|
"""
|
||||||
|
This function is called after processing ends for AlwaysVisible scripts.
|
||||||
|
args contains all values returned by components from ui()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
@ -289,13 +298,22 @@ class ScriptRunner:
|
||||||
|
|
||||||
return processed
|
return processed
|
||||||
|
|
||||||
def run_alwayson_scripts(self, p):
|
def process(self, p):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.alwayson_scripts:
|
||||||
try:
|
try:
|
||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.process(p, *script_args)
|
script.process(p, *script_args)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error running alwayson script: {script.filename}", file=sys.stderr)
|
print(f"Error running process: {script.filename}", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
def postprocess(self, p, processed):
|
||||||
|
for script in self.alwayson_scripts:
|
||||||
|
try:
|
||||||
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
|
script.postprocess(p, processed, *script_args)
|
||||||
|
except Exception:
|
||||||
|
print(f"Error running postprocess: {script.filename}", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
def reload_sources(self, cache):
|
def reload_sources(self, cache):
|
||||||
|
|
Loading…
Reference in a new issue