From 254be4eeb259fd3bd2452250fca1bd278fa248ff Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Wed, 16 Aug 2023 21:45:19 -0400 Subject: Add extra noise callback --- modules/script_callbacks.py | 26 ++++++++++++++++++++++++++ modules/sd_samplers_kdiffusion.py | 4 ++++ modules/sd_samplers_timesteps.py | 4 ++++ 3 files changed, 34 insertions(+) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 77ee55ee..fab23551 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -28,6 +28,15 @@ class ImageSaveParams: """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'""" +class ExtraNoiseParams: + def __init__(self, noise, x): + self.noise = noise + """Random noise generated by the seed""" + + self.x = x + """Latent image representation of the image""" + + class CFGDenoiserParams: def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond): self.x = x @@ -100,6 +109,7 @@ callback_map = dict( callbacks_ui_settings=[], callbacks_before_image_saved=[], callbacks_image_saved=[], + callbacks_extra_noise=[], callbacks_cfg_denoiser=[], callbacks_cfg_denoised=[], callbacks_cfg_after_cfg=[], @@ -189,6 +199,14 @@ def image_saved_callback(params: ImageSaveParams): report_exception(c, 'image_saved_callback') +def extra_noise_callback(params: ExtraNoiseParams): + for c in callback_map['callbacks_extra_noise']: + try: + c.callback(params) + except Exception: + report_exception(c, 'callbacks_extra_noise') + + def cfg_denoiser_callback(params: CFGDenoiserParams): for c in callback_map['callbacks_cfg_denoiser']: try: @@ -367,6 +385,14 @@ def on_image_saved(callback): add_callback(callback_map['callbacks_image_saved'], callback) +def on_extra_noise(callback): + """register a function to be called before adding extra noise in img2img or hires fix; + The callback is called with one argument: + - params: ExtraNoiseParams - contains noise determined by seed and latent representation of image + """ + add_callback(callback_map['callbacks_extra_noise'], callback) + + def on_cfg_denoiser(callback): """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs. The callback is called with one argument: diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 08866e41..b9e0d577 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -3,6 +3,7 @@ import inspect import k_diffusion.sampling from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401 +from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback from modules.shared import opts import modules.shared as shared @@ -147,6 +148,9 @@ class KDiffusionSampler(sd_samplers_common.Sampler): if opts.img2img_extra_noise > 0: p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise + extra_noise_params = ExtraNoiseParams(noise, x) + extra_noise_callback(extra_noise_params) + noise = extra_noise_params.noise xi += noise * opts.img2img_extra_noise extra_params_kwargs = self.initialize(p) diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index 68aea454..7a6cbd46 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -3,6 +3,7 @@ import inspect import sys from modules import devices, sd_samplers_common, sd_samplers_timesteps_impl from modules.sd_samplers_cfg_denoiser import CFGDenoiser +from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback from modules.shared import opts import modules.shared as shared @@ -106,6 +107,9 @@ class CompVisSampler(sd_samplers_common.Sampler): if opts.img2img_extra_noise > 0: p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise + extra_noise_params = ExtraNoiseParams(noise, x) + extra_noise_callback(extra_noise_params) + noise = extra_noise_params.noise xi += noise * opts.img2img_extra_noise * sqrt_alpha_cumprod extra_params_kwargs = self.initialize(p) -- cgit v1.2.1