author    AUTOMATIC1111 <16777216c@gmail.com>    2023-01-13 14:58:03 +0300
committer GitHub <noreply@github.com>            2023-01-13 14:58:03 +0300
commit    1849f6eb806f637f783b3beee3b48772da1cfab1 (patch)
tree      345be78dd1991b77fcf4519bc44097e975e0b0c4
parent    544e7a233e994f379dd67df08f5f519290b10293 (diff)
parent    9cd7716753c5be47f76b8e5555cc3e7c0f17d34d (diff)
Merge pull request #3264 from Melanpan/tensorboard
Add support for Tensorboard (training)
-rw-r--r--    modules/hypernetworks/hypernetwork.py              15
-rw-r--r--    modules/shared.py                                    3
-rw-r--r--    modules/textual_inversion/textual_inversion.py      31
3 files changed, 48 insertions(+), 1 deletion(-)
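
Taken together, the patch follows the standard torch.utils.tensorboard pattern: create one SummaryWriter per training run (tensorboard_setup), log scalars every step (tensorboard_add), and optionally log validation images (tensorboard_add_image). A minimal standalone sketch of that pattern (the log directory, tag, and loss values below are illustrative, not taken from the patch):

    import os
    from torch.utils.tensorboard import SummaryWriter

    log_directory = "textual_inversion/example-run"  # illustrative path
    os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)

    # flush_secs mirrors the new training_tensorboard_flush_every option (default 120)
    writer = SummaryWriter(log_dir=os.path.join(log_directory, "tensorboard"),
                           flush_secs=120)

    for step in range(100):
        loss = 1.0 / (step + 1)  # stand-in for the real training loss
        writer.add_scalar("Loss/train", loss, global_step=step)

    writer.close()

The resulting event files can then be inspected with "tensorboard --logdir <log_directory>/tensorboard".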
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 194679e8..83cbb4f0 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -24,7 +24,6 @@ from statistics import stdev, mean
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
-
class HypernetworkModule(torch.nn.Module):
    multiplier = 1.0
    activation_dict = {
@@ -498,6 +497,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
    if clip_grad:
        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)

+    if shared.opts.training_enable_tensorboard:
+        tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)
+
    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
@@ -632,6 +634,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
        save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
        hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
+
+
+    if shared.opts.training_enable_tensorboard:
+        epoch_num = hypernetwork.step // len(ds)
+        epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
+
+        textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
+
    textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
        "loss": f"{loss_step:.7f}",
        "learn_rate": scheduler.learn_rate
    })
@@ -673,6 +683,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
        processed = processing.process_images(p)
        image = processed.images[0] if len(processed.images) > 0 else None
+
+        if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+            textual_inversion.tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, hypernetwork.step)

    if unload:
        shared.sd_model.cond_stage_model.to(devices.cpu)
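
The epoch bookkeeping added above derives both values from the global step counter and the dataset length: epoch_num is the number of completed passes over the dataset, and epoch_step is the 1-based position within the current pass. A small self-contained sketch of the same arithmetic (the function name is hypothetical):

    def epoch_position(global_step, dataset_len):
        # Mirrors the arithmetic in train_hypernetwork above.
        epoch_num = global_step // dataset_len                    # completed epochs
        epoch_step = global_step - (epoch_num * dataset_len) + 1  # 1-based step within the epoch
        return epoch_num, epoch_step

    assert epoch_position(0, 10) == (0, 1)    # very first step
    assert epoch_position(9, 10) == (0, 10)   # last step of epoch 0
    assert epoch_position(10, 10) == (1, 1)   # first step of epoch 1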
diff --git a/modules/shared.py b/modules/shared.py
index 1c964237..b90ded52 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -373,6 +373,9 @@ options_templates.update(options_section(('training', "Training"), {
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
+ "training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
+ "training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
+ "training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
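
These entries appear in the settings UI under Training and are read back elsewhere in the patch as attributes on shared.opts (e.g. shared.opts.training_enable_tensorboard). The third option maps directly onto SummaryWriter's flush_secs argument, which controls how often pending events are written to disk; a writer can also be flushed explicitly. A short sketch of that behavior (path and values illustrative):

    from torch.utils.tensorboard import SummaryWriter

    # flush_secs here plays the role of training_tensorboard_flush_every
    writer = SummaryWriter(log_dir="logs/tensorboard", flush_secs=120)
    writer.add_scalar("Loss/train", 0.5, global_step=1)

    writer.flush()  # force pending events to disk ahead of the periodic flush
    writer.close()  # close() also flushes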
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index e23906ca..85210b0e 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -12,6 +12,7 @@ import csv
import safetensors.torch
from PIL import Image, PngImagePlugin
+from torch.utils.tensorboard import SummaryWriter
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
import modules.textual_inversion.dataset
@@ -294,6 +295,30 @@ def write_loss(log_directory, filename, step, epoch_len, values):
        **values,
    })
+def tensorboard_setup(log_directory):
+    os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
+    return SummaryWriter(
+        log_dir=os.path.join(log_directory, "tensorboard"),
+        flush_secs=shared.opts.training_tensorboard_flush_every)
+
+def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num):
+    tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step)
+    tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step)
+    tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step)
+    tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)
+
+def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
+    tensorboard_writer.add_scalar(tag=tag, scalar_value=value, global_step=step)
+
+def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
+    # Convert a pil image to a torch tensor
+    img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
+    img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
+                                 len(pil_image.getbands()))
+    img_tensor = img_tensor.permute((2, 0, 1))
+
+    tensorboard_writer.add_image(tag, img_tensor, global_step=step)
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
    assert model_name, f"{name} not selected"
@@ -372,6 +397,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    old_parallel_processing_allowed = shared.parallel_processing_allowed
+
+    if shared.opts.training_enable_tensorboard:
+        tensorboard_writer = tensorboard_setup(log_directory)

    pin_memory = shared.opts.pin_memory
@@ -535,6 +563,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
        last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
        last_saved_image += f", prompt: {preview_text}"
+
+        if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+            tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
+
        if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
            last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
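
For reference, tensorboard_add_image's PIL-to-tensor conversion can be exercised on its own: add_image expects a CHW tensor, while PIL delivers HWC pixel data, hence the view and permute. A self-contained sketch (image contents, tag, and log path are illustrative):

    import numpy as np
    import torch
    from PIL import Image
    from torch.utils.tensorboard import SummaryWriter

    pil_image = Image.new("RGB", (64, 48), color=(255, 0, 0))  # placeholder image

    img_tensor = torch.as_tensor(np.array(pil_image, copy=True))   # uint8, shape (H, W, C)
    img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
                                 len(pil_image.getbands()))        # PIL size is (W, H)
    img_tensor = img_tensor.permute((2, 0, 1))                     # to (C, H, W) for add_image

    writer = SummaryWriter(log_dir="logs/tensorboard")  # illustrative path
    writer.add_image("Validation at epoch 0", img_tensor, global_step=0)
    writer.close()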