aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-12-10 14:54:02 +0300
committerAUTOMATIC <16777216c@gmail.com>2022-12-10 14:54:16 +0300
commit991e2dcee9d6baa66b5c0b1969c4c07407be933a (patch)
tree8cde65654885dc069aee99994f12ae14ba1aac98
parentd06592267c745b4732026c4e0c499c9a4b3900a1 (diff)
remove NSFW filter and its dependency; if you still want it, find it in the extensions section
-rw-r--r--modules/processing.py7
-rw-r--r--modules/safety.py42
-rw-r--r--modules/scripts.py20
-rw-r--r--modules/shared.py1
-rw-r--r--requirements.txt1
-rw-r--r--requirements_versions.txt1
6 files changed, 23 insertions, 49 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 81400d14..056c9322 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -13,7 +13,7 @@ from skimage import exposure
from typing import Any, Dict, List, Optional
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -571,9 +571,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc()
- if opts.filter_nsfw:
- import modules.safety as safety
- x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
+ if p.scripts is not None:
+ p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
for i, x_sample in enumerate(x_samples_ddim):
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
diff --git a/modules/safety.py b/modules/safety.py
deleted file mode 100644
index cff4b278..00000000
--- a/modules/safety.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import torch
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from transformers import AutoFeatureExtractor
-from PIL import Image
-
-import modules.shared as shared
-
-safety_model_id = "CompVis/stable-diffusion-safety-checker"
-safety_feature_extractor = None
-safety_checker = None
-
-def numpy_to_pil(images):
- """
- Convert a numpy image or a batch of images to a PIL image.
- """
- if images.ndim == 3:
- images = images[None, ...]
- images = (images * 255).round().astype("uint8")
- pil_images = [Image.fromarray(image) for image in images]
-
- return pil_images
-
-# check and replace nsfw content
-def check_safety(x_image):
- global safety_feature_extractor, safety_checker
-
- if safety_feature_extractor is None:
- safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
- safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
-
- safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
- x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
-
- return x_checked_image, has_nsfw_concept
-
-
-def censor_batch(x):
- x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
- x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
- x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
-
- return x
diff --git a/modules/scripts.py b/modules/scripts.py
index b934d881..23ca195d 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -88,6 +88,17 @@ class Script:
pass
+ def postprocess_batch(self, p, *args, **kwargs):
+ """
+ Same as process_batch(), but called for every batch after it has been generated.
+
+ **kwargs will have same items as process_batch, and also:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ - images - torch tensor with all generated images, with values ranging from 0 to 1;
+ """
+
+ pass
+
def postprocess(self, p, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
@@ -347,6 +358,15 @@ class ScriptRunner:
print(f"Error running postprocess: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ def postprocess_batch(self, p, images, **kwargs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess_batch(p, *script_args, images=images, **kwargs)
+ except Exception:
+ print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
def before_component(self, component, **kwargs):
for script in self.scripts:
try:
diff --git a/modules/shared.py b/modules/shared.py
index 44922c91..272267c1 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -367,7 +367,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
- "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
}))
diff --git a/requirements.txt b/requirements.txt
index 05818aa6..678acb4d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,5 @@
accelerate
basicsr
-diffusers
fairscale==0.4.4
fonts
font-roboto
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 035fa82f..185cd066 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -1,5 +1,4 @@
transformers==4.19.2
-diffusers==0.3.0
accelerate==0.12.0
basicsr==1.4.2
gfpgan==1.3.8