aboutsummaryrefslogtreecommitdiff
path: root/modules/textual_inversion/textual_inversion.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r--modules/textual_inversion/textual_inversion.py11
1 files changed, 8 insertions, 3 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 04dda585..6d815c0b 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -11,7 +11,6 @@ import safetensors.torch
import numpy as np
from PIL import Image, PngImagePlugin
-from torch.utils.tensorboard import SummaryWriter
from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
import modules.textual_inversion.dataset
@@ -151,6 +150,7 @@ class EmbeddingDatabase:
return embedding
def get_expected_shape(self):
+ devices.torch_npu_set_device()
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
return vec.shape[1]
@@ -344,6 +344,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
})
def tensorboard_setup(log_directory):
+ from torch.utils.tensorboard import SummaryWriter
os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
return SummaryWriter(
log_dir=os.path.join(log_directory, "tensorboard"),
@@ -448,8 +449,12 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
+ tensorboard_writer = None
if shared.opts.training_enable_tensorboard:
- tensorboard_writer = tensorboard_setup(log_directory)
+ try:
+ tensorboard_writer = tensorboard_setup(log_directory)
+ except ImportError:
+ errors.report("Error initializing tensorboard", exc_info=True)
pin_memory = shared.opts.pin_memory
@@ -622,7 +627,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
- if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+ if tensorboard_writer and shared.opts.training_tensorboard_save_images:
tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: