diff --git a/modules/processing.py b/modules/processing.py index 81400d14..056c9322 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List, Optional import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -571,9 +571,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: devices.torch_gc() - if opts.filter_nsfw: - import modules.safety as safety - x_samples_ddim = modules.safety.censor_batch(x_samples_ddim) + if p.scripts is not None: + p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n) for i, x_sample in enumerate(x_samples_ddim): x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) diff --git a/modules/safety.py b/modules/safety.py deleted file mode 100644 index cff4b278..00000000 --- a/modules/safety.py +++ /dev/null @@ -1,42 +0,0 @@ -import torch -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from transformers import AutoFeatureExtractor -from PIL import Image - -import modules.shared as shared - -safety_model_id = "CompVis/stable-diffusion-safety-checker" -safety_feature_extractor = None -safety_checker = None - -def numpy_to_pil(images): - """ - Convert a numpy image or a batch of images to a PIL image. - """ - if images.ndim == 3: - images = images[None, ...] - images = (images * 255).round().astype("uint8") - pil_images = [Image.fromarray(image) for image in images] - - return pil_images - -# check and replace nsfw content -def check_safety(x_image): - global safety_feature_extractor, safety_checker - - if safety_feature_extractor is None: - safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) - safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) - - safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") - x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) - - return x_checked_image, has_nsfw_concept - - -def censor_batch(x): - x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy() - x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy) - x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) - - return x diff --git a/modules/scripts.py b/modules/scripts.py index b934d881..23ca195d 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -88,6 +88,17 @@ class Script: pass + def postprocess_batch(self, p, *args, **kwargs): + """ + Same as process_batch(), but called for every batch after it has been generated. + + **kwargs will have same items as process_batch, and also: + - batch_number - index of current batch, from 0 to number of batches-1 + - images - torch tensor with all generated images, with values ranging from 0 to 1; + """ + + pass + def postprocess(self, p, processed, *args): """ This function is called after processing ends for AlwaysVisible scripts. @@ -347,6 +358,15 @@ class ScriptRunner: print(f"Error running postprocess: {script.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) + def postprocess_batch(self, p, images, **kwargs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.postprocess_batch(p, *script_args, images=images, **kwargs) + except Exception: + print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + def before_component(self, component, **kwargs): for script in self.scripts: try: diff --git a/modules/shared.py b/modules/shared.py index 44922c91..272267c1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -367,7 +367,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), - "filter_nsfw": OptionInfo(False, "Filter NSFW content"), 'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) diff --git a/requirements.txt b/requirements.txt index 05818aa6..678acb4d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ accelerate basicsr -diffusers fairscale==0.4.4 fonts font-roboto diff --git a/requirements_versions.txt b/requirements_versions.txt index 035fa82f..185cd066 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -1,5 +1,4 @@ transformers==4.19.2 -diffusers==0.3.0 accelerate==0.12.0 basicsr==1.4.2 gfpgan==1.3.8