aboutsummaryrefslogtreecommitdiff
path: root/webui.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-08-31 15:09:40 +0300
committerAUTOMATIC <16777216c@gmail.com>2022-08-31 15:09:40 +0300
commit53e7616b5133a0bffc799cae8b1a66395f975f3a (patch)
tree0ab0c3040dfbf17acb649807083685fe08f1a79d /webui.py
parent9427e4e290ef2a6f1d127e2ab2748629a51f31f5 (diff)
DDIM support returned for img2img
Diffstat (limited to 'webui.py')
-rw-r--r--webui.py79
1 files changed, 55 insertions, 24 deletions
diff --git a/webui.py b/webui.py
index b8088795..80952b79 100644
--- a/webui.py
+++ b/webui.py
@@ -94,7 +94,7 @@ samplers = [
SamplerData('DDIM', lambda: VanillaStableDiffusionSampler(DDIMSampler)),
SamplerData('PLMS', lambda: VanillaStableDiffusionSampler(PLMSSampler)),
]
-samplers_for_img2img = [x for x in samplers if x.name != 'DDIM' and x.name != 'PLMS']
+samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
@@ -835,9 +835,37 @@ class StableDiffusionProcessing:
raise NotImplementedError()
+def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
+ if sampler_wrapper.mask is not None:
+ img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts)
+ x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec
+
+ return sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
+
+
class VanillaStableDiffusionSampler:
def __init__(self, constructor):
self.sampler = constructor(sd_model)
+ self.orig_p_sample_ddim = self.sampler.p_sample_ddim
+ self.sampler.p_sample_ddim = lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs)
+ self.mask = None
+ self.nmask = None
+ self.init_latent = None
+
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
+ t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
+
+ self.sampler.make_schedule(ddim_num_steps=p.steps, ddim_eta=0.0, verbose=False)
+ x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(device), noise=noise)
+
+ self.mask = p.mask
+ self.nmask = p.nmask
+ self.init_latent = p.init_latent
+
+ samples = self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)
+
+ return samples
+
def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
samples_ddim, _ = self.sampler.sample(S=p.steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
@@ -864,6 +892,27 @@ class KDiffusionSampler:
self.func = getattr(k_diffusion.sampling, self.funcname)
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
+ t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
+ sigmas = self.model_wrap.get_sigmas(p.steps)
+ noise = noise * sigmas[p.steps - t_enc - 1]
+
+ xi = x + noise
+
+ if p.mask is not None:
+ if p.inpainting_fill == 2:
+ xi = xi * p.mask + noise * p.nmask
+ elif p.inpainting_fill == 3:
+ xi = xi * p.mask
+
+ sigma_sched = sigmas[p.steps - t_enc - 1:]
+
+ def mask_cb(v):
+ v["denoised"][:] = v["denoised"][:] * p.nmask + p.init_latent * p.mask
+
+ return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=mask_cb if p.mask is not None else None)
+
+
def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
sigmas = self.model_wrap.get_sigmas(p.steps)
x = x * sigmas[0]
@@ -1246,39 +1295,20 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.original_mask = self.original_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)).convert('L')
latmask = self.original_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
- latmask = np.moveaxis(np.array(latmask, dtype=np.float), 2, 0) / 255
+ latmask = np.moveaxis(np.array(latmask, dtype=np.float64), 2, 0) / 255
latmask = latmask[0]
latmask = np.tile(latmask[None], (4, 1, 1))
self.mask = torch.asarray(1.0 - latmask).to(device).type(sd_model.dtype)
self.nmask = torch.asarray(latmask).to(device).type(sd_model.dtype)
-
-
def sample(self, x, conditioning, unconditional_conditioning):
- t_enc = int(min(self.denoising_strength, 0.999) * self.steps)
-
- sigmas = self.sampler.model_wrap.get_sigmas(self.steps)
- noise = x * sigmas[self.steps - t_enc - 1]
- xi = self.init_latent + noise
+ samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
if self.mask is not None:
- if self.inpainting_fill == 2:
- xi = xi * self.mask + noise * self.nmask
- elif self.inpainting_fill == 3:
- xi = xi * self.mask
+ samples = samples * self.nmask + self.init_latent * self.mask
- sigma_sched = sigmas[self.steps - t_enc - 1:]
-
- def mask_cb(v):
- v["denoised"][:] = v["denoised"][:] * self.nmask + self.init_latent * self.mask
-
- samples_ddim = self.sampler.func(self.sampler.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': self.cfg_scale}, disable=False, callback=mask_cb if self.mask is not None else None)
-
- if self.mask is not None:
- samples_ddim = samples_ddim * self.nmask + self.init_latent * self.mask
-
- return samples_ddim
+ return samples
def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, prompt_matrix, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
@@ -1544,6 +1574,7 @@ def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_in
if have_realesrgan and RealESRGAN_upscaling != 1.0:
image = upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index)
+ os.makedirs(outpath, exist_ok=True)
base_count = len(os.listdir(outpath))
save_image(image, outpath, f"{base_count:05}", None, '', opts.samples_format, short_filename=True)