aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/hypernetworks/hypernetwork.py38
-rw-r--r--modules/textual_inversion/textual_inversion.py48
2 files changed, 58 insertions, 28 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 2e84583b..38f35c58 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -332,7 +332,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
- assert hypernetwork_name, 'hypernetwork not selected'
+ save_hypernetwork_every = save_hypernetwork_every or 0
+ create_image_every = create_image_every or 0
+ textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork()
@@ -358,39 +360,43 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
else:
images_dir = None
+ hypernetwork = shared.loaded_hypernetwork
+
+ ititial_step = hypernetwork.step or 0
+ if ititial_step > steps:
+ shared.state.textinfo = f"Model has already been trained beyond specified max steps"
+ return hypernetwork, filename
+
+ scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+
+ # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
+
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
- hypernetwork = shared.loaded_hypernetwork
- weights = hypernetwork.weights()
- for weight in weights:
- weight.requires_grad = True
-
size = len(ds.indexes)
loss_dict = defaultdict(lambda : deque(maxlen = 1024))
losses = torch.zeros((size,))
previous_mean_losses = [0]
previous_mean_loss = 0
print("Mean loss of {} elements".format(size))
-
- last_saved_file = "<none>"
- last_saved_image = "<none>"
- forced_filename = "<none>"
-
- ititial_step = hypernetwork.step or 0
- if ititial_step > steps:
- return hypernetwork, filename
-
- scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+
+ weights = hypernetwork.weights()
+ for weight in weights:
+ weight.requires_grad = True
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
steps_without_grad = 0
+ last_saved_file = "<none>"
+ last_saved_image = "<none>"
+ forced_filename = "<none>"
+
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar:
hypernetwork.step = i + ititial_step
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 17dfb223..44f06443 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -204,9 +204,30 @@ def write_loss(log_directory, filename, step, epoch_len, values):
**values,
})
+def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
+ assert model_name, f"{name} not selected"
+ assert learn_rate, "Learning rate is empty or 0"
+ assert isinstance(batch_size, int), "Batch size must be integer"
+ assert batch_size > 0, "Batch size must be positive"
+ assert data_root, "Dataset directory is empty"
+ assert os.path.isdir(data_root), "Dataset directory doesn't exist"
+ assert os.listdir(data_root), "Dataset directory is empty"
+ assert template_file, "Prompt template file is empty"
+ assert os.path.isfile(template_file), "Prompt template file doesn't exist"
+ assert steps, "Max steps is empty or 0"
+ assert isinstance(steps, int), "Max steps must be integer"
+ assert steps > 0 , "Max steps must be positive"
+ assert isinstance(save_model_every, int), "Save {name} must be integer"
+ assert save_model_every >= 0 , "Save {name} must be positive or 0"
+ assert isinstance(create_image_every, int), "Create image must be integer"
+ assert create_image_every >= 0 , "Create image must be positive or 0"
+ if save_model_every or create_image_every:
+ assert log_directory, "Log directory is empty"
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
- assert embedding_name, 'embedding not selected'
+ save_embedding_every = save_embedding_every or 0
+ create_image_every = create_image_every or 0
+ validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
@@ -232,17 +253,27 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
os.makedirs(images_embeds_dir, exist_ok=True)
else:
images_embeds_dir = None
-
+
cond_model = shared.sd_model.cond_stage_model
+ hijack = sd_hijack.model_hijack
+
+ embedding = hijack.embedding_db.word_embeddings[embedding_name]
+
+ ititial_step = embedding.step or 0
+ if ititial_step > steps:
+ shared.state.textinfo = f"Model has already been trained beyond specified max steps"
+ return embedding, filename
+
+ scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+
+ # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
- hijack = sd_hijack.model_hijack
-
- embedding = hijack.embedding_db.word_embeddings[embedding_name]
embedding.vec.requires_grad = True
+ optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
losses = torch.zeros((32,))
@@ -251,13 +282,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
forced_filename = "<none>"
embedding_yet_to_be_embedded = False
- ititial_step = embedding.step or 0
- if ititial_step > steps:
- return embedding, filename
-
- scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
- optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
-
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, entries in pbar:
embedding.step = i + ititial_step