diff --git a/modules/images.py b/modules/images.py index bfc2ba06..7870b5b7 100644 --- a/modules/images.py +++ b/modules/images.py @@ -451,17 +451,6 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i """ namegen = FilenameGenerator(p, seed, prompt) - if extension == 'png' and opts.enable_pnginfo and info is not None: - pnginfo = PngImagePlugin.PngInfo() - - if existing_info is not None: - for k, v in existing_info.items(): - pnginfo.add_text(k, str(v)) - - pnginfo.add_text(pnginfo_section_name, info) - else: - pnginfo = None - if save_to_dirs is None: save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt) @@ -489,19 +478,27 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i if add_number: basecount = get_next_sequence_number(path, basename) fullfn = None - fullfn_without_extension = None for i in range(500): fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}" fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}") - fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}") if not os.path.exists(fullfn): break else: fullfn = os.path.join(path, f"{file_decoration}.{extension}") - fullfn_without_extension = os.path.join(path, file_decoration) else: fullfn = os.path.join(path, f"{forced_filename}.{extension}") - fullfn_without_extension = os.path.join(path, forced_filename) + + pnginfo = existing_info or {} + if info is not None: + pnginfo[pnginfo_section_name] = info + + params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo) + script_callbacks.before_image_saved_callback(params) + + image = params.image + fullfn = params.filename + info = params.pnginfo.get(pnginfo_section_name, None) + fullfn_without_extension, extension = os.path.splitext(params.filename) def exif_bytes(): return piexif.dump({ @@ -510,12 +507,20 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i }, }) - if extension.lower() in ("jpg", "jpeg", "webp"): + if extension.lower() == '.png': + pnginfo_data = PngImagePlugin.PngInfo() + for k, v in params.pnginfo.items(): + pnginfo_data.add_text(k, str(v)) + + image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo_data) + + elif extension.lower() in (".jpg", ".jpeg", ".webp"): image.save(fullfn, quality=opts.jpeg_quality) + if opts.enable_pnginfo and info is not None: piexif.insert(exif_bytes(), fullfn) else: - image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo) + image.save(fullfn, quality=opts.jpeg_quality) target_side_length = 4000 oversize = image.width > target_side_length or image.height > target_side_length @@ -538,7 +543,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i else: txt_fullfn = None - script_callbacks.image_saved_callback(image, p, fullfn, txt_fullfn) + script_callbacks.image_saved_callback(params) + return fullfn, txt_fullfn diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 6803d57b..6ea58d61 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -9,15 +9,34 @@ def report_exception(c, job): print(traceback.format_exc(), file=sys.stderr) +class ImageSaveParams: + def __init__(self, image, p, filename, pnginfo): + self.image = image + """the PIL image itself""" + + self.p = p + """p object with processing parameters; either StableDiffusionProcessing or an object with same fields""" + + self.filename = filename + """name of file that the image would be saved to""" + + self.pnginfo = pnginfo + """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'""" + + ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) callbacks_model_loaded = [] callbacks_ui_tabs = [] callbacks_ui_settings = [] +callbacks_before_image_saved = [] callbacks_image_saved = [] + def clear_callbacks(): callbacks_model_loaded.clear() callbacks_ui_tabs.clear() + callbacks_ui_settings.clear() + callbacks_before_image_saved.clear() callbacks_image_saved.clear() @@ -49,10 +68,18 @@ def ui_settings_callback(): report_exception(c, 'ui_settings_callback') -def image_saved_callback(image, p, fullfn, txt_fullfn): +def before_image_saved_callback(params: ImageSaveParams): for c in callbacks_image_saved: try: - c.callback(image, p, fullfn, txt_fullfn) + c.callback(params) + except Exception: + report_exception(c, 'before_image_saved_callback') + + +def image_saved_callback(params: ImageSaveParams): + for c in callbacks_image_saved: + try: + c.callback(params) except Exception: report_exception(c, 'image_saved_callback') @@ -64,7 +91,6 @@ def add_callback(callbacks, fun): callbacks.append(ScriptCallback(filename, fun)) - def on_model_loaded(callback): """register a function to be called when the stable diffusion model is created; the model is passed as an argument""" @@ -90,11 +116,17 @@ def on_ui_settings(callback): add_callback(callbacks_ui_settings, callback) +def on_before_image_saved(callback): + """register a function to be called before an image is saved to a file. + The callback is called with one argument: + - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object. + """ + add_callback(callbacks_before_image_saved, callback) + + def on_image_saved(callback): - """register a function to be called after modules.images.save_image is called. - The callback is called with three arguments: - - p - procesing object (or a dummy object with same fields if the image is saved using save button) - - fullfn - image filename - - txt_fullfn - text file with parameters; may be None + """register a function to be called after an image is saved to a file. + The callback is called with one argument: + - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing. """ add_callback(callbacks_image_saved, callback)