From 991e2dcee9d6baa66b5c0b1969c4c07407be933a Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 10 Dec 2022 14:54:02 +0300
Subject: remove NSFW filter and its dependency; if you still want it, find it in the extensions section

---
 modules/safety.py | 42 ------------------------------------------
 1 file changed, 42 deletions(-)
 delete mode 100644 modules/safety.py

diff --git a/modules/safety.py b/modules/safety.py
deleted file mode 100644
index cff4b278..00000000
--- a/modules/safety.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import torch
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from transformers import AutoFeatureExtractor
-from PIL import Image
-
-import modules.shared as shared
-
-safety_model_id = "CompVis/stable-diffusion-safety-checker"
-safety_feature_extractor = None
-safety_checker = None
-
-def numpy_to_pil(images):
-    """
-    Convert a numpy image or a batch of images to a PIL image.
-    """
-    if images.ndim == 3:
-        images = images[None, ...]
-    images = (images * 255).round().astype("uint8")
-    pil_images = [Image.fromarray(image) for image in images]
-
-    return pil_images
-
-# check and replace nsfw content
-def check_safety(x_image):
-    global safety_feature_extractor, safety_checker
-
-    if safety_feature_extractor is None:
-        safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
-        safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
-
-    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
-    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
-
-    return x_checked_image, has_nsfw_concept
-
-
-def censor_batch(x):
-    x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
-    x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
-    x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
-
-    return x
-- 
cgit v1.2.1
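
Note (not part of this commit): a minimal sketch of how the deleted helpers were driven, run from the webui root against a checkout that still contains modules/safety.py. The random tensor is only a stand-in for a real batch of decoded samples; everything else mirrors the removed check_safety/censor_batch code rather than any current webui API.

    import torch
    import modules.safety as safety  # the module deleted by this commit

    # x stands in for a batch of decoded samples: float NCHW, values in [0, 1].
    x = torch.rand(1, 3, 512, 512)

    # censor_batch converts the batch to NHWC numpy, runs the CompVis safety
    # checker (downloading it on first use), and returns the possibly
    # blacked-out images as an NCHW tensor again.
    x_censored = safety.censor_batch(x)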