aboutsummaryrefslogtreecommitdiff
path: root/modules/processing.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py41
1 files changed, 25 insertions, 16 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 9c3673de..bc541e2f 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -13,7 +13,7 @@ from skimage import exposure
from typing import Any, Dict, List, Optional
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -140,6 +140,7 @@ class StableDiffusionProcessing:
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
self.override_settings_restore_afterwards = override_settings_restore_afterwards
self.is_using_inpainting_conditioning = False
+ self.disable_extra_networks = False
if not seed_enable_extras:
self.subseed = -1
@@ -438,9 +439,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
- "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
- "Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()),
- "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
@@ -468,14 +466,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
try:
for k, v in p.override_settings.items():
setattr(opts, k, v)
- if k == 'sd_hypernetwork':
- shared.reload_hypernetworks() # make onchange call for changing hypernet
if k == 'sd_model_checkpoint':
- sd_models.reload_model_weights() # make onchange call for changing SD model
+ sd_models.reload_model_weights()
if k == 'sd_vae':
- sd_vae.reload_vae_weights() # make onchange call for changing VAE
+ sd_vae.reload_vae_weights()
res = process_images_inner(p)
@@ -484,9 +480,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.override_settings_restore_afterwards:
for k, v in stored_opts.items():
setattr(opts, k, v)
- if k == 'sd_hypernetwork': shared.reload_hypernetworks()
- if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
- if k == 'sd_vae': sd_vae.reload_vae_weights()
+ if k == 'sd_model_checkpoint':
+ sd_models.reload_model_weights()
+
+ if k == 'sd_vae':
+ sd_vae.reload_vae_weights()
return res
@@ -535,13 +533,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
+ _, extra_network_data = extra_networks.parse_prompts(p.all_prompts[0:1])
+
if p.scripts is not None:
p.scripts.process(p)
- with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
- processed = Processed(p, [], p.seed, "")
- file.write(processed.infotext(p, 0))
-
infotexts = []
output_images = []
@@ -572,6 +568,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
+ if not p.disable_extra_networks:
+ extra_networks.activate(p, extra_network_data)
+
+ with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
+ processed = Processed(p, [], p.seed, "")
+ file.write(processed.infotext(p, 0))
+
if state.job_count == -1:
state.job_count = p.n_iter
@@ -592,6 +595,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if len(prompts) == 0:
break
+ prompts, _ = extra_networks.parse_prompts(prompts)
+
if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
@@ -681,6 +686,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+ if not p.disable_extra_networks:
+ extra_networks.deactivate(p, extra_network_data)
+
devices.torch_gc()
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
@@ -857,7 +865,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
shared.state.nextjob()
- self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
+ img2img_sampler_name = self.sampler_name if self.sampler_name != 'PLMS' else 'DDIM' # PLMS does not support img2img so we just silently switch ot DDIM
+ self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]