aboutsummaryrefslogtreecommitdiff
path: root/modules/textual_inversion
diff options
context:
space:
mode:
Diffstat (limited to 'modules/textual_inversion')
-rw-r--r--modules/textual_inversion/textual_inversion.py25
1 files changed, 18 insertions, 7 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index fa0e33a2..3d835358 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -172,7 +172,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -259,18 +259,29 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
- preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
-
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
- prompt=preview_text,
- steps=20,
- height=training_height,
- width=training_width,
do_not_save_grid=True,
do_not_save_samples=True,
)
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entry.cond_text
+ p.steps = 20
+ p.width = training_width
+ p.height = training_height
+
+ preview_text = p.prompt
+
processed = processing.process_images(p)
image = processed.images[0]