put safety checker into a separate file because it's already crowded in processing
This commit is contained in:
parent
b03bc4e79a
commit
b5a8b99d3f
2 changed files with 44 additions and 34 deletions
|
@ -19,20 +19,11 @@ import modules.face_restoration
|
||||||
import modules.images as images
|
import modules.images as images
|
||||||
import modules.styles
|
import modules.styles
|
||||||
|
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
|
||||||
from transformers import AutoFeatureExtractor
|
|
||||||
|
|
||||||
# load safety model
|
|
||||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
|
||||||
safety_feature_extractor = None
|
|
||||||
safety_checker = None
|
|
||||||
|
|
||||||
# some of those options should not be changed at all because they would break the model, so I removed them from options.
|
# some of those options should not be changed at all because they would break the model, so I removed them from options.
|
||||||
opt_C = 4
|
opt_C = 4
|
||||||
opt_f = 8
|
opt_f = 8
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionProcessing:
|
class StableDiffusionProcessing:
|
||||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", prompt_style="None", seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
|
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", prompt_style="None", seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
|
||||||
self.sd_model = sd_model
|
self.sd_model = sd_model
|
||||||
|
@ -154,28 +145,6 @@ def fix_seed(p):
|
||||||
p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == -1 else p.subseed
|
p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == -1 else p.subseed
|
||||||
|
|
||||||
|
|
||||||
def numpy_to_pil(images):
|
|
||||||
"""
|
|
||||||
Convert a numpy image or a batch of images to a PIL image.
|
|
||||||
"""
|
|
||||||
if images.ndim == 3:
|
|
||||||
images = images[None, ...]
|
|
||||||
images = (images * 255).round().astype("uint8")
|
|
||||||
pil_images = [Image.fromarray(image) for image in images]
|
|
||||||
|
|
||||||
return pil_images
|
|
||||||
|
|
||||||
# check and replace nsfw content
|
|
||||||
def check_safety(x_image):
|
|
||||||
global safety_feature_extractor, safety_checker
|
|
||||||
if safety_feature_extractor is None:
|
|
||||||
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
|
|
||||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
|
||||||
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
|
|
||||||
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
|
|
||||||
return x_checked_image, has_nsfw_concept
|
|
||||||
|
|
||||||
|
|
||||||
def process_images(p: StableDiffusionProcessing) -> Processed:
|
def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||||
|
|
||||||
|
@ -279,9 +248,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
if opts.filter_nsfw:
|
if opts.filter_nsfw:
|
||||||
x_samples_ddim_numpy = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
|
import modules.safety as safety
|
||||||
x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
|
x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
|
||||||
x_samples_ddim = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
|
|
||||||
|
|
||||||
for i, x_sample in enumerate(x_samples_ddim):
|
for i, x_sample in enumerate(x_samples_ddim):
|
||||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||||
|
|
42
modules/safety.py
Normal file
42
modules/safety.py
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
import torch
|
||||||
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
|
from transformers import AutoFeatureExtractor
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
import modules.shared as shared
|
||||||
|
|
||||||
|
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
|
safety_feature_extractor = None
|
||||||
|
safety_checker = None
|
||||||
|
|
||||||
|
def numpy_to_pil(images):
|
||||||
|
"""
|
||||||
|
Convert a numpy image or a batch of images to a PIL image.
|
||||||
|
"""
|
||||||
|
if images.ndim == 3:
|
||||||
|
images = images[None, ...]
|
||||||
|
images = (images * 255).round().astype("uint8")
|
||||||
|
pil_images = [Image.fromarray(image) for image in images]
|
||||||
|
|
||||||
|
return pil_images
|
||||||
|
|
||||||
|
# check and replace nsfw content
|
||||||
|
def check_safety(x_image):
|
||||||
|
global safety_feature_extractor, safety_checker
|
||||||
|
|
||||||
|
if safety_feature_extractor is None:
|
||||||
|
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
|
||||||
|
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
||||||
|
|
||||||
|
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
|
||||||
|
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
|
||||||
|
|
||||||
|
return x_checked_image, has_nsfw_concept
|
||||||
|
|
||||||
|
|
||||||
|
def censor_batch(x):
|
||||||
|
x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
|
||||||
|
x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
|
||||||
|
x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
|
||||||
|
|
||||||
|
return x
|
Loading…
Reference in a new issue