aboutsummaryrefslogtreecommitdiff
path: root/modules/hypernetworks/hypernetwork.py
diff options
context:
space:
mode:
authordiscus0434 <66945496+discus0434@users.noreply.github.com>2022-10-19 15:18:45 +0900
committerGitHub <noreply@github.com>2022-10-19 15:18:45 +0900
commit7f8670c4ef71440f690824d4b9bd432cc2926a3e (patch)
tree8cd53d62296d5bfe92e4581eb5c197588ccd1257 /modules/hypernetworks/hypernetwork.py
parent5d16f5979434bc8ee7f0301b3d6de74ac99a6b3f (diff)
parentda72becb13e4b750fbcb3d158c3f843311ef9938 (diff)
Merge branch 'master' into master
Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
-rw-r--r--modules/hypernetworks/hypernetwork.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 082165f4..583ada31 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -244,7 +244,7 @@ def stack_conds(conds):
return torch.stack(conds)
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None)
@@ -273,8 +273,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
- assert ds.length > 1, "Dataset should contain more than 1 images"
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)