aboutsummaryrefslogtreecommitdiff
path: root/webui.py
diff options
context:
space:
mode:
Diffstat (limited to 'webui.py')
-rw-r--r--webui.py137
1 files changed, 69 insertions, 68 deletions
diff --git a/webui.py b/webui.py
index 1a2fa56c..6f8efa84 100644
--- a/webui.py
+++ b/webui.py
@@ -106,6 +106,30 @@ class CFGDenoiser(nn.Module):
return uncond + (cond - uncond) * cond_scale
+class KDiffusionSampler:
+ def __init__(self, m):
+ self.model = m
+ self.model_wrap = K.external.CompVisDenoiser(m)
+
+ def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T):
+ sigmas = self.model_wrap.get_sigmas(S)
+ x = x_T * sigmas[0]
+ model_wrap_cfg = CFGDenoiser(self.model_wrap)
+ samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False)
+
+ return samples_ddim, None
+
+
+def create_random_tensors(seed, shape, count, same_seed=False):
+ xs = []
+ for i in range(count):
+ current_seed = seed if same_seed else seed + i
+ torch.manual_seed(current_seed)
+ xs.append(torch.randn(shape, device=device))
+ x = torch.stack(xs)
+ return x
+
+
def load_GFPGAN():
model_name = 'GFPGANv1.3'
model_path = os.path.join(GFPGAN_dir, 'experiments/pretrained_models', model_name + '.pth')
@@ -166,22 +190,15 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, pro
seed = int(seed)
keep_same_seed = False
- is_PLMS = sampler_name == 'PLMS'
- is_DDIM = sampler_name == 'DDIM'
- is_Kdif = sampler_name == 'k-diffusion'
-
- sampler = None
- if is_PLMS:
+ if sampler_name == 'PLMS':
sampler = PLMSSampler(model)
- elif is_DDIM:
+ elif sampler_name == 'DDIM':
sampler = DDIMSampler(model)
- elif is_Kdif:
- pass
+ elif sampler_name == 'k-diffusion':
+ sampler = KDiffusionSampler(model)
else:
raise Exception("Unknown sampler: " + sampler_name)
- model_wrap = K.external.CompVisDenoiser(model)
-
os.makedirs(outpath, exist_ok=True)
batch_size = n_samples
@@ -238,21 +255,9 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, pro
batch_seed = seed if keep_same_seed else seed + n * len(prompts)
# we manually generate all input noises because each one should have a specific seed
- xs = []
- for i in range(len(prompts)):
- current_seed = seed if keep_same_seed else batch_seed + i
- torch.manual_seed(current_seed)
- xs.append(torch.randn(shape, device=device))
- x = torch.stack(xs)
-
- if is_Kdif:
- sigmas = model_wrap.get_sigmas(ddim_steps)
- x = x * sigmas[0]
- model_wrap_cfg = CFGDenoiser(model_wrap)
- samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}, disable=False)
-
- elif sampler is not None:
- samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=c, batch_size=len(prompts), shape=shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=ddim_eta, x_T=x)
+ x = create_random_tensors(batch_seed, shape, count=len(prompts), same_seed=keep_same_seed)
+
+ samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=c, batch_size=len(prompts), shape=shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=ddim_eta, x_T=x)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -274,9 +279,6 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, pro
output_images.append(image)
base_count += 1
-
-
-
if not opt.skip_grid:
# additionally, save as grid
grid = image_grid(output_images, batch_size, round_down=prompt_matrix)
@@ -380,13 +382,11 @@ def translation(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, ddim_e
batch_size = n_samples
assert prompt is not None
- data = [batch_size * [prompt]]
sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outpath)) - 1
- seedit = 0
image = init_img.convert("RGB")
image = image.resize((width, height), resample=PIL.Image.Resampling.LANCZOS)
@@ -407,43 +407,44 @@ def translation(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, ddim_e
t_enc = int(denoising_strength * ddim_steps)
for n in range(n_iter):
- for batch_index, prompts in enumerate(data):
- uc = None
- if cfg_scale != 1.0:
- uc = model.get_learned_conditioning(batch_size * [""])
- if isinstance(prompts, tuple):
- prompts = list(prompts)
- c = model.get_learned_conditioning(prompts)
-
- sigmas = model_wrap.get_sigmas(ddim_steps)
-
- current_seed = seed + n * len(data) + batch_index
- torch.manual_seed(current_seed)
-
- noise = torch.randn_like(x0) * sigmas[ddim_steps - t_enc - 1] # for GPU draw
- xi = x0 + noise
- sigma_sched = sigmas[ddim_steps - t_enc - 1:]
- model_wrap_cfg = CFGDenoiser(model_wrap)
- extra_args = {'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}
-
- samples_ddim = K.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args=extra_args, disable=False)
- x_samples_ddim = model.decode_first_stage(samples_ddim)
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-
- if not opt.skip_save or not opt.skip_grid:
- for x_sample in x_samples_ddim:
- x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
- x_sample = x_sample.astype(np.uint8)
-
- if use_GFPGAN and GFPGAN is not None:
- cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True)
- x_sample = restored_img
-
- image = Image.fromarray(x_sample)
- image.save(os.path.join(sample_path, f"{base_count:05}-{current_seed}_{prompt.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"))
-
- output_images.append(image)
- base_count += 1
+ prompts = batch_size * [prompt]
+
+ uc = None
+ if cfg_scale != 1.0:
+ uc = model.get_learned_conditioning(batch_size * [""])
+ if isinstance(prompts, tuple):
+ prompts = list(prompts)
+ c = model.get_learned_conditioning(prompts)
+
+ batch_seed = seed + n * len(prompts)
+
+ sigmas = model_wrap.get_sigmas(ddim_steps)
+ noise = create_random_tensors(batch_seed, x0.shape[1:], count=len(prompts))
+ noise = noise * sigmas[ddim_steps - t_enc - 1]
+
+ xi = x0 + noise
+ sigma_sched = sigmas[ddim_steps - t_enc - 1:]
+ model_wrap_cfg = CFGDenoiser(model_wrap)
+ extra_args = {'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}
+
+ samples_ddim = K.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args=extra_args, disable=False)
+ x_samples_ddim = model.decode_first_stage(samples_ddim)
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+
+ if not opt.skip_save or not opt.skip_grid:
+ for i, x_sample in enumerate(x_samples_ddim):
+ x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+ x_sample = x_sample.astype(np.uint8)
+
+ if use_GFPGAN and GFPGAN is not None:
+ cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True)
+ x_sample = restored_img
+
+ image = Image.fromarray(x_sample)
+ image.save(os.path.join(sample_path, f"{base_count:05}-{batch_seed+i}_{prompt.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"))
+
+ output_images.append(image)
+ base_count += 1
if not opt.skip_grid:
# additionally, save as grid