aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/processing.py18
-rw-r--r--modules/sd_samplers_common.py13
-rw-r--r--modules/shared_options.py1
3 files changed, 18 insertions, 14 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 0138e5ac..f696e925 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1148,18 +1148,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
decoded_samples = None
- current = shared.sd_model.sd_checkpoint_info
- try:
- if self.hr_checkpoint_info is not None:
- self.sampler = None
- sd_models.reload_model_weights(info=self.hr_checkpoint_info)
- devices.torch_gc()
-
- return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
- finally:
- self.sampler = None
- sd_models.reload_model_weights(info=current)
- devices.torch_gc()
+ with sd_models.SkipWritingToConfig():
+ sd_models.reload_model_weights(info=self.hr_checkpoint_info)
+
+ devices.torch_gc()
+
+ return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
if shared.state.interrupted:
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 60fa161c..6c935a38 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -164,8 +164,17 @@ def apply_refiner(cfg_denoiser):
if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
return False
- if getattr(cfg_denoiser.p, "enable_hr", False) and not cfg_denoiser.p.is_hr_pass:
- return False
+ if getattr(cfg_denoiser.p, "enable_hr", False):
+ is_second_pass = cfg_denoiser.p.is_hr_pass
+
+ if opts.hires_fix_refiner_pass == "first pass" and is_second_pass:
+ return False
+
+ if opts.hires_fix_refiner_pass == "second pass" and not is_second_pass:
+ return False
+
+ if opts.hires_fix_refiner_pass != "second pass":
+ cfg_denoiser.p.extra_generation_params['Hires refiner'] = opts.hires_fix_refiner_pass
cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
diff --git a/modules/shared_options.py b/modules/shared_options.py
index 78652ea2..00b273fa 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -146,6 +146,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}, infotext="RNG").info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
"tiling": OptionInfo(False, "Tiling", infotext='Tiling').info("produce a tileable picture"),
+ "hires_fix_refiner_pass": OptionInfo("second pass", "Hires fix: which pass to enable refiner for", gr.Radio, {"choices": ["first pass", "second pass", "both passes"]}, infotext="Hires refiner"),
}))
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {