aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/sd_samplers.py33
1 files changed, 15 insertions, 18 deletions
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 4fe67854..44112f99 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -1,4 +1,4 @@
-from collections import namedtuple
+from collections import namedtuple, deque
import numpy as np
from math import floor
import torch
@@ -335,18 +335,28 @@ class CFGDenoiser(torch.nn.Module):
class TorchHijack:
- def __init__(self, kdiff_sampler):
- self.kdiff_sampler = kdiff_sampler
+ def __init__(self, sampler_noises):
+ # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
+ # implementation.
+ self.sampler_noises = deque(sampler_noises)
def __getattr__(self, item):
if item == 'randn_like':
- return self.kdiff_sampler.randn_like
+ return self.randn_like
if hasattr(torch, item):
return getattr(torch, item)
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
+ def randn_like(self, x):
+ if self.sampler_noises:
+ noise = self.sampler_noises.popleft()
+ if noise.shape == x.shape:
+ return noise
+
+ return torch.randn_like(x)
+
class KDiffusionSampler:
def __init__(self, funcname, sd_model):
@@ -356,7 +366,6 @@ class KDiffusionSampler:
self.extra_params = sampler_extra_params.get(funcname, [])
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
self.sampler_noises = None
- self.sampler_noise_index = 0
self.stop_at = None
self.eta = None
self.default_eta = 1.0
@@ -389,26 +398,14 @@ class KDiffusionSampler:
def number_of_needed_noises(self, p):
return p.steps
- def randn_like(self, x):
- noise = self.sampler_noises[self.sampler_noise_index] if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises) else None
-
- if noise is not None and x.shape == noise.shape:
- res = noise
- else:
- res = torch.randn_like(x)
-
- self.sampler_noise_index += 1
- return res
-
def initialize(self, p):
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
self.model_wrap.step = 0
- self.sampler_noise_index = 0
self.eta = p.eta or opts.eta_ancestral
if self.sampler_noises is not None:
- k_diffusion.sampling.torch = TorchHijack(self)
+ k_diffusion.sampling.torch = TorchHijack(self.sampler_noises)
extra_params_kwargs = {}
for param_name in self.extra_params: