aboutsummaryrefslogtreecommitdiff
path: root/modules/processing.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-08-10 17:05:32 +0300
committerGitHub <noreply@github.com>2023-08-10 17:05:32 +0300
commit36762f0eaf04c270dde23849cb198446ecdc4100 (patch)
tree879b63e94d986f8d4fb30d65ee5aa4ae45f3e640 /modules/processing.py
parent959404e0e29531d24f2e02088bf0399f4b9db15b (diff)
parentac8a5d18d3ede6bcb8fa5a3da1c7c28e064cd65d (diff)
Merge pull request #12371 from AUTOMATIC1111/refiner
initial refiner support
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 44d47e8c..efa6eafa 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -377,6 +377,9 @@ class StableDiffusionProcessing:
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
+ def get_conds(self):
+ return self.c, self.uc
+
def parse_extra_network_prompts(self):
self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
@@ -611,6 +614,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
try:
+ # after running refiner, the refiner model is not unloaded - webui swaps back to main model here
+ if shared.sd_model.sd_checkpoint_info.title != opts.sd_model_checkpoint:
+ sd_models.reload_model_weights()
+
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
p.override_settings.pop('sd_model_checkpoint', None)
@@ -710,6 +717,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.interrupted:
break
+ sd_models.reload_model_weights() # model can be changed for example by refiner
+
p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
@@ -1201,6 +1210,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
with devices.autocast():
extra_networks.activate(self, self.extra_network_data)
+ def get_conds(self):
+ if self.is_hr_pass:
+ return self.hr_c, self.hr_uc
+
+ return super().get_conds()
+
+
def parse_extra_network_prompts(self):
res = super().parse_extra_network_prompts()