aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers_common.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-08-10 17:04:38 +0300
committerAUTOMATIC1111 <16777216c@gmail.com>2023-08-10 17:04:38 +0300
commit70a01cd4440d708bf25cc50393c0430935a8ebc2 (patch)
treedaad07800a3dadfd3caeac1383c1c65ecfcb6284 /modules/sd_samplers_common.py
parent1aefb5025929818b2a96cbb6148fcc2db7b947ec (diff)
parent070b034cd5b49eb5056a18b43f88aa223fec9e0b (diff)
Merge branch 'dev' into refiner
Diffstat (limited to 'modules/sd_samplers_common.py')
-rw-r--r--modules/sd_samplers_common.py24
1 files changed, 12 insertions, 12 deletions
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index b6ad6830..35c4d657 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -1,5 +1,5 @@
import inspect
-from collections import namedtuple, deque
+from collections import namedtuple
import numpy as np
import torch
from PIL import Image
@@ -161,10 +161,15 @@ def apply_refiner(sampler):
class TorchHijack:
- def __init__(self, sampler_noises):
- # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
- # implementation.
- self.sampler_noises = deque(sampler_noises)
+ """This is here to replace torch.randn_like of k-diffusion.
+
+ k-diffusion has random_sampler argument for most samplers, but not for all, so
+ this is needed to properly replace every use of torch.randn_like.
+
+ We need to replace to make images generated in batches to be same as images generated individually."""
+
+ def __init__(self, p):
+ self.rng = p.rng
def __getattr__(self, item):
if item == 'randn_like':
@@ -176,12 +181,7 @@ class TorchHijack:
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
def randn_like(self, x):
- if self.sampler_noises:
- noise = self.sampler_noises.popleft()
- if noise.shape == x.shape:
- return noise
-
- return devices.randn_like(x)
+ return self.rng.next()
class Sampler:
@@ -248,7 +248,7 @@ class Sampler:
self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
- k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
+ k_diffusion.sampling.torch = TorchHijack(p)
extra_params_kwargs = {}
for param_name in self.extra_params: