aboutsummaryrefslogtreecommitdiff
path: root/modules/processing.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-08-24 09:41:16 +0300
committerGitHub <noreply@github.com>2023-08-24 09:41:16 +0300
commit0027ce1f6e58c5f4279173e468a60aef420babfe (patch)
treece1c72dcfb95af92bc666526c123ae99078dccb6 /modules/processing.py
parent06f18186dcb4d866cf8e80ff69a032de7aff7ab5 (diff)
parent99ab3d43a71e3f66e57d3cd2013b97c97e7ab69b (diff)
Merge pull request #12457 from rubberbaron/shared-hires-prompt-test
prompt editing timeline has separate range for first pass and hires-fix pass
Diffstat (limited to 'modules/processing.py')
-rw-r--r--[-rwxr-xr-x]modules/processing.py16
1 files changed, 10 insertions, 6 deletions
diff --git a/modules/processing.py b/modules/processing.py
index e60cc92b..066351c1 100755..100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -407,12 +407,14 @@ class StableDiffusionProcessing:
self.main_prompt = self.all_prompts[0]
self.main_negative_prompt = self.all_negative_prompts[0]
- def cached_params(self, required_prompts, steps, extra_network_data):
+ def cached_params(self, required_prompts, steps, extra_network_data, hires_steps=None, use_old_scheduling=False):
"""Returns parameters that invalidate the cond cache if changed"""
return (
required_prompts,
steps,
+ hires_steps,
+ use_old_scheduling,
opts.CLIP_stop_at_last_layers,
shared.sd_model.sd_checkpoint_info,
extra_network_data,
@@ -422,7 +424,7 @@ class StableDiffusionProcessing:
self.height,
)
- def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
+ def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
"""
Returns the result of calling function(shared.sd_model, required_prompts, steps)
using a cache to store the result if the same arguments have been used before.
@@ -435,7 +437,7 @@ class StableDiffusionProcessing:
caches is a list with items described above.
"""
- cached_params = self.cached_params(required_prompts, steps, extra_network_data)
+ cached_params = self.cached_params(required_prompts, steps, extra_network_data, hires_steps, shared.opts.use_old_scheduling)
for cache in caches:
if cache[0] is not None and cached_params == cache[0]:
@@ -444,7 +446,7 @@ class StableDiffusionProcessing:
cache = caches[0]
with devices.autocast():
- cache[1] = function(shared.sd_model, required_prompts, steps)
+ cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
cache[0] = cached_params
return cache[1]
@@ -456,6 +458,8 @@ class StableDiffusionProcessing:
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps
self.step_multiplier = total_steps // self.steps
+ self.firstpass_steps = total_steps
+
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)
@@ -1292,8 +1296,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
steps = self.hr_second_pass_steps or self.steps
total_steps = sampler_config.total_steps(steps) if sampler_config else steps
- self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, total_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
- self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, total_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
+ self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
+ self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)
def setup_conds(self):
if self.is_hr_pass: