From 96d6ca4199e7c5eee8d451618de5161cea317c40 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 08:25:25 +0300 Subject: manual fixes for ruff --- modules/models/diffusion/ddpm_edit.py | 26 ++++++++++---------------- modules/models/diffusion/uni_pc/sampler.py | 3 ++- 2 files changed, 12 insertions(+), 17 deletions(-) (limited to 'modules/models/diffusion') diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py index f880bc3c..611c2b69 100644 --- a/modules/models/diffusion/ddpm_edit.py +++ b/modules/models/diffusion/ddpm_edit.py @@ -479,7 +479,7 @@ class LatentDiffusion(DDPM): self.cond_stage_key = cond_stage_key try: self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 - except: + except Exception: self.num_downs = 0 if not scale_by_std: self.scale_factor = scale_factor @@ -891,16 +891,6 @@ class LatentDiffusion(DDPM): c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) return self.p_losses(x, c, t, *args, **kwargs) - def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset - def rescale_bbox(bbox): - x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) - y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) - w = min(bbox[2] / crop_coordinates[2], 1 - x0) - h = min(bbox[3] / crop_coordinates[3], 1 - y0) - return x0, y0, w, h - - return [rescale_bbox(b) for b in bboxes] - def apply_model(self, x_noisy, t, cond, return_ids=False): if isinstance(cond, dict): @@ -1171,8 +1161,10 @@ class LatentDiffusion(DDPM): if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(x0_partial) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) return img, intermediates @torch.no_grad() @@ -1219,8 +1211,10 @@ class LatentDiffusion(DDPM): if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(img) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) if return_intermediates: return img, intermediates @@ -1337,7 +1331,7 @@ class LatentDiffusion(DDPM): if inpaint: # make a simple center square - b, h, w = z.shape[0], z.shape[2], z.shape[3] + h, w = z.shape[2], z.shape[3] mask = torch.ones(N, h, w).to(self.device) # zeros will be filled in mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py index a241c8a7..0a9defa1 100644 --- a/modules/models/diffusion/uni_pc/sampler.py +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -54,7 +54,8 @@ class UniPCSampler(object): if conditioning is not None: if isinstance(conditioning, dict): ctmp = conditioning[list(conditioning.keys())[0]] - while isinstance(ctmp, list): ctmp = ctmp[0] + while isinstance(ctmp, list): + ctmp = ctmp[0] cbs = ctmp.shape[0] if cbs != batch_size: print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") -- cgit v1.2.1 From f741a98baccae100fcfb40c017b5c35c5cba1b0c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 08:43:42 +0300 Subject: imports cleanup for ruff --- modules/models/diffusion/uni_pc/uni_pc.py | 1 - 1 file changed, 1 deletion(-) (limited to 'modules/models/diffusion') diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py index 11b330bc..a4c4ef4e 100644 --- a/modules/models/diffusion/uni_pc/uni_pc.py +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -1,5 +1,4 @@ import torch -import torch.nn.functional as F import math from tqdm.auto import trange -- cgit v1.2.1 From 4b854806d98cf5ccd48e5cd99c172613da7937f0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 09:02:23 +0300 Subject: F401 fixes for ruff --- modules/models/diffusion/uni_pc/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/models/diffusion') diff --git a/modules/models/diffusion/uni_pc/__init__.py b/modules/models/diffusion/uni_pc/__init__.py index e1265e3f..dbb35964 100644 --- a/modules/models/diffusion/uni_pc/__init__.py +++ b/modules/models/diffusion/uni_pc/__init__.py @@ -1 +1 @@ -from .sampler import UniPCSampler +from .sampler import UniPCSampler # noqa: F401 -- cgit v1.2.1 From 028d3f6425d85f122027c127fba8bcbf4f66ee75 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 11:05:02 +0300 Subject: ruff auto fixes --- modules/models/diffusion/ddpm_edit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/models/diffusion') diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py index 611c2b69..09432117 100644 --- a/modules/models/diffusion/ddpm_edit.py +++ b/modules/models/diffusion/ddpm_edit.py @@ -1130,7 +1130,7 @@ class LatentDiffusion(DDPM): if cond is not None: if isinstance(cond, dict): cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + [x[:batch_size] for x in cond[key]] for key in cond} else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] @@ -1229,7 +1229,7 @@ class LatentDiffusion(DDPM): if cond is not None: if isinstance(cond, dict): cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + [x[:batch_size] for x in cond[key]] for key in cond} else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] return self.p_sample_loop(cond, -- cgit v1.2.1 From 550256db1ce18778a9d56ff343d844c61b9f9b83 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 11:19:16 +0300 Subject: ruff manual fixes --- modules/models/diffusion/ddpm_edit.py | 14 ++++++++------ modules/models/diffusion/uni_pc/uni_pc.py | 7 +++++-- 2 files changed, 13 insertions(+), 8 deletions(-) (limited to 'modules/models/diffusion') diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py index 09432117..af4dea15 100644 --- a/modules/models/diffusion/ddpm_edit.py +++ b/modules/models/diffusion/ddpm_edit.py @@ -52,7 +52,7 @@ class DDPM(pl.LightningModule): beta_schedule="linear", loss_type="l2", ckpt_path=None, - ignore_keys=[], + ignore_keys=None, load_only_unet=False, monitor="val/loss", use_ema=True, @@ -107,7 +107,7 @@ class DDPM(pl.LightningModule): print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet) # If initialing from EMA-only checkpoint, create EMA model after loading. if self.use_ema and not load_ema: @@ -194,7 +194,9 @@ class DDPM(pl.LightningModule): if context is not None: print(f"{context}: Restored training weights") - def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + def init_from_ckpt(self, path, ignore_keys=None, only_model=False): + ignore_keys = ignore_keys or [] + sd = torch.load(path, map_location="cpu") if "state_dict" in list(sd.keys()): sd = sd["state_dict"] @@ -473,7 +475,7 @@ class LatentDiffusion(DDPM): conditioning_key = None ckpt_path = kwargs.pop("ckpt_path", None) ignore_keys = kwargs.pop("ignore_keys", []) - super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs) + super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs) self.concat_mode = concat_mode self.cond_stage_trainable = cond_stage_trainable self.cond_stage_key = cond_stage_key @@ -1433,10 +1435,10 @@ class Layout2ImgDiffusion(LatentDiffusion): # TODO: move all layout-specific hacks to this class def __init__(self, cond_stage_key, *args, **kwargs): assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' - super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs) def log_images(self, batch, N=8, *args, **kwargs): - logs = super().log_images(batch=batch, N=N, *args, **kwargs) + logs = super().log_images(*args, batch=batch, N=N, **kwargs) key = 'train' if self.training else 'validation' dset = self.trainer.datamodule.datasets[key] diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py index a4c4ef4e..6f8ad631 100644 --- a/modules/models/diffusion/uni_pc/uni_pc.py +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -178,13 +178,13 @@ def model_wrapper( model, noise_schedule, model_type="noise", - model_kwargs={}, + model_kwargs=None, guidance_type="uncond", #condition=None, #unconditional_condition=None, guidance_scale=1., classifier_fn=None, - classifier_kwargs={}, + classifier_kwargs=None, ): """Create a wrapper function for the noise prediction model. @@ -275,6 +275,9 @@ def model_wrapper( A noise prediction model that accepts the noised data and the continuous time as the inputs. """ + model_kwargs = model_kwargs or [] + classifier_kwargs = classifier_kwargs or [] + def get_model_input_time(t_continuous): """ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. -- cgit v1.2.1 From d25219b7e889cf34bccae9cb88497708796efda2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 11:55:09 +0300 Subject: manual fixes for some C408 --- modules/models/diffusion/ddpm_edit.py | 8 ++++---- modules/models/diffusion/uni_pc/uni_pc.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'modules/models/diffusion') diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py index af4dea15..3fb76b65 100644 --- a/modules/models/diffusion/ddpm_edit.py +++ b/modules/models/diffusion/ddpm_edit.py @@ -405,7 +405,7 @@ class DDPM(pl.LightningModule): @torch.no_grad() def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): - log = dict() + log = {} x = self.get_input(batch, self.first_stage_key) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) @@ -413,7 +413,7 @@ class DDPM(pl.LightningModule): log["inputs"] = x # get diffusion row - diffusion_row = list() + diffusion_row = [] x_start = x[:n_row] for t in range(self.num_timesteps): @@ -1263,7 +1263,7 @@ class LatentDiffusion(DDPM): use_ddim = False - log = dict() + log = {} z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, return_first_stage_outputs=True, force_c_encode=True, @@ -1291,7 +1291,7 @@ class LatentDiffusion(DDPM): if plot_diffusion_rows: # get diffusion row - diffusion_row = list() + diffusion_row = [] z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py index 6f8ad631..f6c49f87 100644 --- a/modules/models/diffusion/uni_pc/uni_pc.py +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -344,7 +344,7 @@ def model_wrapper( t_in = torch.cat([t_continuous] * 2) if isinstance(condition, dict): assert isinstance(unconditional_condition, dict) - c_in = dict() + c_in = {} for k in condition: if isinstance(condition[k], list): c_in[k] = [torch.cat([ @@ -355,7 +355,7 @@ def model_wrapper( unconditional_condition[k], condition[k]]) elif isinstance(condition, list): - c_in = list() + c_in = [] assert isinstance(unconditional_condition, list) for i in range(len(condition)): c_in.append(torch.cat([unconditional_condition[i], condition[i]])) -- cgit v1.2.1 From 3ec7b705c78b7aca9569c92a419837352c7a4ec6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 21:21:32 +0300 Subject: suggestions and fixes from the PR --- modules/models/diffusion/uni_pc/uni_pc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/models/diffusion') diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py index f6c49f87..a227b947 100644 --- a/modules/models/diffusion/uni_pc/uni_pc.py +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -275,8 +275,8 @@ def model_wrapper( A noise prediction model that accepts the noised data and the continuous time as the inputs. """ - model_kwargs = model_kwargs or [] - classifier_kwargs = classifier_kwargs or [] + model_kwargs = model_kwargs or {} + classifier_kwargs = classifier_kwargs or {} def get_model_input_time(t_continuous): """ -- cgit v1.2.1 From ae17e97898af8dd776b20e104ba9a81fe699e4df Mon Sep 17 00:00:00 2001 From: Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com> Date: Thu, 11 May 2023 12:26:04 +0800 Subject: UniPC progress bar adjustment --- modules/models/diffusion/uni_pc/uni_pc.py | 70 ++++++++++++++++--------------- 1 file changed, 37 insertions(+), 33 deletions(-) (limited to 'modules/models/diffusion') diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py index eb5f4e76..1d1b07bd 100644 --- a/modules/models/diffusion/uni_pc/uni_pc.py +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -1,7 +1,7 @@ import torch import torch.nn.functional as F import math -from tqdm.auto import trange +import tqdm class NoiseScheduleVP: @@ -757,40 +757,44 @@ class UniPC: vec_t = timesteps[0].expand((x.shape[0])) model_prev_list = [self.model_fn(x, vec_t)] t_prev_list = [vec_t] - # Init the first `order` values by lower order multistep DPM-Solver. - for init_order in range(1, order): - vec_t = timesteps[init_order].expand(x.shape[0]) - x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True) - if model_x is None: - model_x = self.model_fn(x, vec_t) - if self.after_update is not None: - self.after_update(x, model_x) - model_prev_list.append(model_x) - t_prev_list.append(vec_t) - for step in trange(order, steps + 1): - vec_t = timesteps[step].expand(x.shape[0]) - if lower_order_final: - step_order = min(order, steps + 1 - step) - else: - step_order = order - #print('this step order:', step_order) - if step == steps: - #print('do not run corrector at the last step') - use_corrector = False - else: - use_corrector = True - x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector) - if self.after_update is not None: - self.after_update(x, model_x) - for i in range(order - 1): - t_prev_list[i] = t_prev_list[i + 1] - model_prev_list[i] = model_prev_list[i + 1] - t_prev_list[-1] = vec_t - # We do not need to evaluate the final model value. - if step < steps: + with tqdm.tqdm(total=steps) as pbar: + # Init the first `order` values by lower order multistep DPM-Solver. + for init_order in range(1, order): + vec_t = timesteps[init_order].expand(x.shape[0]) + x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True) if model_x is None: model_x = self.model_fn(x, vec_t) - model_prev_list[-1] = model_x + if self.after_update is not None: + self.after_update(x, model_x) + model_prev_list.append(model_x) + t_prev_list.append(vec_t) + pbar.update() + + for step in range(order, steps + 1): + vec_t = timesteps[step].expand(x.shape[0]) + if lower_order_final: + step_order = min(order, steps + 1 - step) + else: + step_order = order + #print('this step order:', step_order) + if step == steps: + #print('do not run corrector at the last step') + use_corrector = False + else: + use_corrector = True + x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector) + if self.after_update is not None: + self.after_update(x, model_x) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = vec_t + # We do not need to evaluate the final model value. + if step < steps: + if model_x is None: + model_x = self.model_fn(x, vec_t) + model_prev_list[-1] = model_x + pbar.update() else: raise NotImplementedError() if denoise_to_zero: -- cgit v1.2.1