author    AUTOMATIC <16777216c@gmail.com>    2022-09-16 09:48:46 +0300
committer AUTOMATIC <16777216c@gmail.com>    2022-09-16 09:48:46 +0300
commit    83bce1a604f40356583d31c8c0b2f8b590dda071 (patch)
tree      c094f7241b9e001367fce750f6344dfa8a66ebcb /modules
parent    b44ddcb44398fbe922fd7515f66d8b0c2344bc54 (diff)
parent    87e8b9a2ab3f033e7fdadbb2fe258857915980ac (diff)
Merge branch 'batch-seed-attempt'
Diffstat (limited to 'modules')
-rw-r--r--  modules/devices.py      10
-rw-r--r--  modules/processing.py   18
-rw-r--r--  modules/sd_samplers.py  40
-rw-r--r--  modules/shared.py        1
4 files changed, 67 insertions(+), 2 deletions(-)
diff --git a/modules/devices.py b/modules/devices.py
index e4430e1a..07bb2339 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -48,3 +48,13 @@ def randn(seed, shape):
     torch.manual_seed(seed)
     return torch.randn(shape, device=device)
+
+
+def randn_without_seed(shape):
+    # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
+    if device.type == 'mps':
+        generator = torch.Generator(device=cpu)
+        noise = torch.randn(shape, generator=generator, device=cpu).to(device)
+        return noise
+
+    return torch.randn(shape, device=device)
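The new helper mirrors `randn` above but deliberately skips reseeding: `randn` pins the global RNG to a per-image seed, while `randn_without_seed` continues from whatever state the RNG is already in, so repeated calls yield fresh draws. (`device` and `cpu` are module-level globals in devices.py.) A minimal self-contained sketch of the distinction, using plain torch on CPU:

    import torch

    def randn_seeded(seed, shape):
        torch.manual_seed(seed)          # pin the RNG: same seed, same tensor
        return torch.randn(shape)

    def randn_unseeded(shape):
        return torch.randn(shape)        # continue from the current RNG state

    a = randn_seeded(1234, (4,))
    b = randn_unseeded((4,))             # depends on the state left by the call above
    c = randn_seeded(1234, (4,))
    assert torch.equal(a, c)             # seeded draws are reproducible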
diff --git a/modules/processing.py b/modules/processing.py
index 71a9c6f5..798313ee 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -119,8 +119,14 @@ def slerp(val, low, high):
     return res


-def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
+def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
     xs = []
+
+    if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds:
+        sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
+    else:
+        sampler_noises = None
+
     for i, seed in enumerate(seeds):
         noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
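When a batch carries more than one seed and the option is enabled, the function asks the sampler up front how many extra noise tensors it will draw during sampling and reserves one list per draw; the per-seed loop below then fills each list with one tensor per image. A sketch of that layout with hypothetical numbers (3 sampler draws, a batch of 2):

    import torch

    number_of_needed_noises = 3        # hypothetical: e.g. one draw per sampler step
    seeds = [100, 101]                 # a batch of two images

    sampler_noises = [[] for _ in range(number_of_needed_noises)]
    for _ in seeds:                    # the per-seed loop fills one slot per image
        for j in range(number_of_needed_noises):
            sampler_noises[j].append(torch.randn(4, 8, 8))

    assert all(len(n) == len(seeds) for n in sampler_noises)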
@@ -155,9 +161,17 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
             x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
             noise = x

+        if sampler_noises is not None:
+            cnt = p.sampler.number_of_needed_noises(p)
+            for j in range(cnt):
+                sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))

         xs.append(noise)
+
+    if sampler_noises is not None:
+        p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]
+
     x = torch.stack(xs).to(shared.device)
     return x
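After the loop, each per-draw list is stacked along a new batch dimension and handed to the sampler, so draw j during sampling sees one (batch, C, H, W) tensor whose slice i was generated right after image i's own seeded noise. A shape check under made-up sizes:

    import torch

    steps, batch = 2, 3                                 # hypothetical sizes
    sampler_noises = [[torch.randn(4, 8, 8) for _ in range(batch)]
                      for _ in range(steps)]

    batched = [torch.stack(n) for n in sampler_noises]
    assert batched[0].shape == (3, 4, 8, 8)             # one batch tensor per draw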
@@ -257,7 +271,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
                 comments[comment] = 1

         # we manually generate all input noises because each one should have a specific seed
-        x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
+        x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)

         if p.n_iter > 1:
             shared.state.job = f"Batch {n+1} out of {p.n_iter}"
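The call-site comment is the crux of the whole change: every image's initial latent comes from its own seed, so any image from a batch can be regenerated alone. A standalone sketch of that property (the shape (4, 8, 8) is a stand-in for the real latent size):

    import torch

    def seeded_latent(seed, shape=(4, 8, 8)):
        g = torch.Generator().manual_seed(seed)         # per-image generator
        return torch.randn(shape, generator=g)

    seeds = [100, 101, 102]
    x = torch.stack([seeded_latent(s) for s in seeds])  # the batch of latents

    assert torch.equal(x[1], seeded_latent(101))        # image 1 reproducible alone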
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 5d95bfe0..02ffce0e 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -80,8 +80,12 @@ class VanillaStableDiffusionSampler:
         self.mask = None
         self.nmask = None
         self.init_latent = None
+        self.sampler_noises = None
         self.step = 0

+    def number_of_needed_noises(self, p):
+        return 0
+
     def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
         cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
         unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
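The vanilla sampler reports that it needs no pre-generated noises, while KDiffusionSampler (below) reserves one draw per step; processing.py relies only on this shared duck-typed method, not on the concrete class. An illustrative pairing, not the project's actual classes:

    class DdimLikeSampler:
        def number_of_needed_noises(self, p):
            return 0                 # no mid-sampling draws to pre-generate

    class KDiffusionLikeSampler:
        def number_of_needed_noises(self, p):
            return p.steps           # worst case: one fresh draw per step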
@@ -185,16 +189,46 @@ def extended_trange(count, *args, **kwargs):
         shared.total_tqdm.update()


+class TorchHijack:
+    def __init__(self, kdiff_sampler):
+        self.kdiff_sampler = kdiff_sampler
+
+    def __getattr__(self, item):
+        if item == 'randn_like':
+            return self.kdiff_sampler.randn_like
+
+        if hasattr(torch, item):
+            return getattr(torch, item)
+
+        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
+
+
 class KDiffusionSampler:
     def __init__(self, funcname, sd_model):
         self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
         self.funcname = funcname
         self.func = getattr(k_diffusion.sampling, self.funcname)
         self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
+        self.sampler_noises = None
+        self.sampler_noise_index = 0

     def callback_state(self, d):
         store_latent(d["denoised"])

+    def number_of_needed_noises(self, p):
+        return p.steps
+
+    def randn_like(self, x):
+        noise = self.sampler_noises[self.sampler_noise_index] if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises) else None
+
+        if noise is not None and x.shape == noise.shape:
+            res = noise
+        else:
+            res = torch.randn_like(x)
+
+        self.sampler_noise_index += 1
+        return res
+
     def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
         t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
         sigmas = self.model_wrap.get_sigmas(p.steps)
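TorchHijack and randn_like work as a pair: the hijack object impersonates the `torch` module inside k_diffusion.sampling, intercepting only `randn_like` and forwarding every other attribute to the real torch, while `randn_like` hands out the pre-generated noises in order and falls back to a genuine `torch.randn_like` when the queue is exhausted or shapes disagree (the index still advances, keeping later draws aligned). A self-contained sketch of both halves, with stand-in names:

    import torch

    class NoiseQueue:
        # stand-in for KDiffusionSampler's sampler_noises / sampler_noise_index pair
        def __init__(self, noises):
            self.noises = noises
            self.index = 0

        def randn_like(self, x):
            noise = self.noises[self.index] if self.index < len(self.noises) else None
            self.index += 1                      # advance even on fallback
            if noise is not None and noise.shape == x.shape:
                return noise
            return torch.randn_like(x)

    class FakeTorch:
        # stand-in for TorchHijack: override one name, delegate the rest
        def __init__(self, queue):
            self.queue = queue

        def __getattr__(self, item):
            if item == 'randn_like':
                return self.queue.randn_like
            return getattr(torch, item)          # AttributeError propagates naturally

    q = NoiseQueue([torch.zeros(2, 2)])
    fake = FakeTorch(q)
    assert fake.randn_like(torch.ones(2, 2)).sum() == 0   # served from the queue
    assert fake.cat is torch.cat                          # the rest passes through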
@@ -213,6 +247,9 @@ class KDiffusionSampler:
         if hasattr(k_diffusion.sampling, 'trange'):
             k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)

+        if self.sampler_noises is not None:
+            k_diffusion.sampling.torch = TorchHijack(self)
+
         return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)

     def sample(self, p, x, conditioning, unconditional_conditioning):
@@ -224,6 +261,9 @@ class KDiffusionSampler:
         if hasattr(k_diffusion.sampling, 'trange'):
             k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)

+        if self.sampler_noises is not None:
+            k_diffusion.sampling.torch = TorchHijack(self)
+
         samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)
         return samples_ddim
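Both sample paths install the hijack the same way: k_diffusion's sampling functions resolve `torch.randn_like` through their module's global `torch` name at call time, so rebinding `k_diffusion.sampling.torch` to the proxy redirects every mid-sampling draw without modifying k_diffusion itself. A toy demonstration of the late-binding trick (FakeSamplingModule stands in for k_diffusion.sampling):

    import torch

    class FakeSamplingModule:
        # stands in for k_diffusion.sampling: its code reads `torch` at call time
        torch = torch

        @classmethod
        def sample_step(cls, x):
            return cls.torch.randn_like(x)

    class ZeroHijack:
        def __getattr__(self, item):
            if item == 'randn_like':
                return torch.zeros_like      # pretend this serves queued noise
            return getattr(torch, item)

    FakeSamplingModule.torch = ZeroHijack()  # same move as the diff's assignment
    assert FakeSamplingModule.sample_step(torch.ones(3)).sum() == 0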
diff --git a/modules/shared.py b/modules/shared.py
index 78450546..fa6a0e99 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -128,6 +128,7 @@ class Options:
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"font": OptionInfo("", "Font for image grids that have text"),
"enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text and [text] to make it pay less attention"),
+ "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
"ESRGAN_tile": OptionInfo(192, "Tile size for upscaling. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for upscaling. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),