aboutsummaryrefslogtreecommitdiff
path: root/modules/hypernetworks/hypernetwork.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
-rw-r--r--modules/hypernetworks/hypernetwork.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 70f1cbd2..be3e4648 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -468,7 +468,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
shared.reload_hypernetworks()
-def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(id_task, hypernetwork_name: str, learn_rate: float, batch_size: int, gradient_step: int, data_root: str, log_directory: str, training_width: int, training_height: int, varsize: bool, steps: int, clip_grad_mode: str, clip_grad_value: float, shuffle_tags: bool, tag_drop_out: bool, latent_sampling_method: str, use_weight: bool, create_image_every: int, save_hypernetwork_every: int, template_filename: str, preview_from_txt2img: bool, preview_prompt: str, preview_negative_prompt: str, preview_steps: int, preview_sampler_name: str, preview_cfg_scale: float, preview_seed: int, preview_width: int, preview_height: int):
from modules import images, processing
save_hypernetwork_every = save_hypernetwork_every or 0
@@ -698,7 +698,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
- p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
+ p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width