From 762265eab58cdb8f2d6398769bab43d8b8db0075 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 07:52:45 +0300 Subject: autofixes from ruff --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 4368eb63..f753b75f 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -603,7 +603,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st try: vectorSize = list(data['string_to_param'].values())[0].shape[0] - except Exception as e: + except Exception: vectorSize = '?' checkpoint = sd_models.select_checkpoint() -- cgit v1.2.1 From 96d6ca4199e7c5eee8d451618de5161cea317c40 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 08:25:25 +0300 Subject: manual fixes for ruff --- modules/textual_inversion/autocrop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index ba1bdcd4..d7d8d2e3 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -185,7 +185,7 @@ def image_face_points(im, settings): try: faces = classifier.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) - except: + except Exception: continue if len(faces) > 0: -- cgit v1.2.1 From f741a98baccae100fcfb40c017b5c35c5cba1b0c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 08:43:42 +0300 Subject: imports cleanup for ruff --- modules/textual_inversion/autocrop.py | 4 +--- modules/textual_inversion/image_embedding.py | 2 +- modules/textual_inversion/preprocess.py | 4 ---- modules/textual_inversion/textual_inversion.py | 1 - 4 files changed, 2 insertions(+), 9 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index d7d8d2e3..7770d22f 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -1,10 +1,8 @@ import cv2 import requests import os -from collections import defaultdict -from math import log, sqrt import numpy as np -from PIL import Image, ImageDraw +from PIL import ImageDraw GREEN = "#0F0" BLUE = "#00F" diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index 5593f88c..ee0e850a 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -2,7 +2,7 @@ import base64 import json import numpy as np import zlib -from PIL import Image, PngImagePlugin, ImageDraw, ImageFont +from PIL import Image, ImageDraw, ImageFont from fonts.ttf import Roboto import torch from modules.shared import opts diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index da0bcb26..d0cad09e 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -1,13 +1,9 @@ import os from PIL import Image, ImageOps import math -import platform -import sys import tqdm -import time from modules import paths, shared, images, deepbooru -from modules.shared import opts, cmd_opts from modules.textual_inversion import autocrop diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index f753b75f..9ed9ba45 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -1,7 +1,6 @@ import os import sys import traceback -import inspect from collections import namedtuple import torch -- cgit v1.2.1 From 028d3f6425d85f122027c127fba8bcbf4f66ee75 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 11:05:02 +0300 Subject: ruff auto fixes --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 9ed9ba45..c37bb2ad 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -167,7 +167,7 @@ class EmbeddingDatabase: if 'string_to_param' in data: param_dict = data['string_to_param'] if hasattr(param_dict, '_parameters'): - param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 + param_dict = param_dict._parameters # fix for torch 1.12.1 loading saved file from torch 1.11 assert len(param_dict) == 1, 'embedding file has multiple terms in it' emb = next(iter(param_dict.items()))[1] # diffuser concepts -- cgit v1.2.1 From 550256db1ce18778a9d56ff343d844c61b9f9b83 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 11:19:16 +0300 Subject: ruff manual fixes --- modules/textual_inversion/image_embedding.py | 2 +- modules/textual_inversion/learn_schedule.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index ee0e850a..d85a4888 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -17,7 +17,7 @@ class EmbeddingEncoder(json.JSONEncoder): class EmbeddingDecoder(json.JSONDecoder): def __init__(self, *args, **kwargs): - json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs) + json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs) def object_hook(self, d): if 'TORCHTENSOR' in d: diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py index f63fc72f..fda58898 100644 --- a/modules/textual_inversion/learn_schedule.py +++ b/modules/textual_inversion/learn_schedule.py @@ -32,8 +32,8 @@ class LearnScheduleIterator: self.maxit += 1 return assert self.rates - except (ValueError, AssertionError): - raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') + except (ValueError, AssertionError) as e: + raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e def __iter__(self): -- cgit v1.2.1 From a5121e7a0623db328a9462d340d389ed6737374a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 11:37:18 +0300 Subject: fixes for B007 --- modules/textual_inversion/learn_schedule.py | 2 +- modules/textual_inversion/textual_inversion.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py index fda58898..c56bea45 100644 --- a/modules/textual_inversion/learn_schedule.py +++ b/modules/textual_inversion/learn_schedule.py @@ -12,7 +12,7 @@ class LearnScheduleIterator: self.it = 0 self.maxit = 0 try: - for i, pair in enumerate(pairs): + for pair in pairs: if not pair.strip(): continue tmp = pair.split(':') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index c37bb2ad..47035332 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -29,7 +29,7 @@ textual_inversion_templates = {} def list_textual_inversion_templates(): textual_inversion_templates.clear() - for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir): + for root, _, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir): for fn in fns: path = os.path.join(root, fn) @@ -198,7 +198,7 @@ class EmbeddingDatabase: if not os.path.isdir(embdir.path): return - for root, dirs, fns in os.walk(embdir.path, followlinks=True): + for root, _, fns in os.walk(embdir.path, followlinks=True): for fn in fns: try: fullfn = os.path.join(root, fn) @@ -215,7 +215,7 @@ class EmbeddingDatabase: def load_textual_inversion_embeddings(self, force_reload=False): if not force_reload: need_reload = False - for path, embdir in self.embedding_dirs.items(): + for embdir in self.embedding_dirs.values(): if embdir.has_changed(): need_reload = True break @@ -228,7 +228,7 @@ class EmbeddingDatabase: self.skipped_embeddings.clear() self.expected_shape = self.get_expected_shape() - for path, embdir in self.embedding_dirs.items(): + for embdir in self.embedding_dirs.values(): self.load_from_dir(embdir) embdir.update() @@ -469,7 +469,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st try: sd_hijack_checkpoint.add() - for i in range((steps-initial_step) * gradient_step): + for _ in range((steps-initial_step) * gradient_step): if scheduler.finished: break if shared.state.interrupted: -- cgit v1.2.1 From 3ec7b705c78b7aca9569c92a419837352c7a4ec6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 21:21:32 +0300 Subject: suggestions and fixes from the PR --- modules/textual_inversion/textual_inversion.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 47035332..9e1b2b9a 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -166,8 +166,7 @@ class EmbeddingDatabase: # textual inversion embeddings if 'string_to_param' in data: param_dict = data['string_to_param'] - if hasattr(param_dict, '_parameters'): - param_dict = param_dict._parameters # fix for torch 1.12.1 loading saved file from torch 1.11 + param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 assert len(param_dict) == 1, 'embedding file has multiple terms in it' emb = next(iter(param_dict.items()))[1] # diffuser concepts -- cgit v1.2.1 From df7070eca22278b25c921ef72d3f97a221d66242 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 11 May 2023 10:06:19 +0300 Subject: Deduplicate get_font code --- modules/textual_inversion/image_embedding.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index d85a4888..5858a55f 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -3,9 +3,7 @@ import json import numpy as np import zlib from PIL import Image, ImageDraw, ImageFont -from fonts.ttf import Roboto import torch -from modules.shared import opts class EmbeddingEncoder(json.JSONEncoder): @@ -136,11 +134,8 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t image = srcimage.copy() fontsize = 32 if textfont is None: - try: - textfont = ImageFont.truetype(opts.font or Roboto, fontsize) - textfont = opts.font or Roboto - except Exception: - textfont = Roboto + from modules.images import get_font + textfont = get_font(fontsize) factor = 1.5 gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0)) -- cgit v1.2.1 From 098d2fda5250f9f418fc641bd0f185cb461ee6d9 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 11 May 2023 18:13:35 +0300 Subject: Reindent autocrop with 4 spaces --- modules/textual_inversion/autocrop.py | 202 +++++++++++++++++----------------- 1 file changed, 102 insertions(+), 100 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index 7770d22f..8e667a4d 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -10,63 +10,64 @@ RED = "#F00" def crop_image(im, settings): - """ Intelligently crop an image to the subject matter """ - - scale_by = 1 - if is_landscape(im.width, im.height): - scale_by = settings.crop_height / im.height - elif is_portrait(im.width, im.height): - scale_by = settings.crop_width / im.width - elif is_square(im.width, im.height): - if is_square(settings.crop_width, settings.crop_height): - scale_by = settings.crop_width / im.width - elif is_landscape(settings.crop_width, settings.crop_height): - scale_by = settings.crop_width / im.width - elif is_portrait(settings.crop_width, settings.crop_height): - scale_by = settings.crop_height / im.height - - im = im.resize((int(im.width * scale_by), int(im.height * scale_by))) - im_debug = im.copy() - - focus = focal_point(im_debug, settings) - - # take the focal point and turn it into crop coordinates that try to center over the focal - # point but then get adjusted back into the frame - y_half = int(settings.crop_height / 2) - x_half = int(settings.crop_width / 2) - - x1 = focus.x - x_half - if x1 < 0: - x1 = 0 - elif x1 + settings.crop_width > im.width: - x1 = im.width - settings.crop_width - - y1 = focus.y - y_half - if y1 < 0: - y1 = 0 - elif y1 + settings.crop_height > im.height: - y1 = im.height - settings.crop_height - - x2 = x1 + settings.crop_width - y2 = y1 + settings.crop_height - - crop = [x1, y1, x2, y2] - - results = [] - - results.append(im.crop(tuple(crop))) - - if settings.annotate_image: - d = ImageDraw.Draw(im_debug) - rect = list(crop) - rect[2] -= 1 - rect[3] -= 1 - d.rectangle(rect, outline=GREEN) - results.append(im_debug) - if settings.destop_view_image: - im_debug.show() - - return results + """ Intelligently crop an image to the subject matter """ + + scale_by = 1 + if is_landscape(im.width, im.height): + scale_by = settings.crop_height / im.height + elif is_portrait(im.width, im.height): + scale_by = settings.crop_width / im.width + elif is_square(im.width, im.height): + if is_square(settings.crop_width, settings.crop_height): + scale_by = settings.crop_width / im.width + elif is_landscape(settings.crop_width, settings.crop_height): + scale_by = settings.crop_width / im.width + elif is_portrait(settings.crop_width, settings.crop_height): + scale_by = settings.crop_height / im.height + + + im = im.resize((int(im.width * scale_by), int(im.height * scale_by))) + im_debug = im.copy() + + focus = focal_point(im_debug, settings) + + # take the focal point and turn it into crop coordinates that try to center over the focal + # point but then get adjusted back into the frame + y_half = int(settings.crop_height / 2) + x_half = int(settings.crop_width / 2) + + x1 = focus.x - x_half + if x1 < 0: + x1 = 0 + elif x1 + settings.crop_width > im.width: + x1 = im.width - settings.crop_width + + y1 = focus.y - y_half + if y1 < 0: + y1 = 0 + elif y1 + settings.crop_height > im.height: + y1 = im.height - settings.crop_height + + x2 = x1 + settings.crop_width + y2 = y1 + settings.crop_height + + crop = [x1, y1, x2, y2] + + results = [] + + results.append(im.crop(tuple(crop))) + + if settings.annotate_image: + d = ImageDraw.Draw(im_debug) + rect = list(crop) + rect[2] -= 1 + rect[3] -= 1 + d.rectangle(rect, outline=GREEN) + results.append(im_debug) + if settings.destop_view_image: + im_debug.show() + + return results def focal_point(im, settings): corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else [] @@ -86,7 +87,7 @@ def focal_point(im, settings): corner_centroid = None if len(corner_points) > 0: corner_centroid = centroid(corner_points) - corner_centroid.weight = settings.corner_points_weight / weight_pref_total + corner_centroid.weight = settings.corner_points_weight / weight_pref_total pois.append(corner_centroid) entropy_centroid = None @@ -98,7 +99,7 @@ def focal_point(im, settings): face_centroid = None if len(face_points) > 0: face_centroid = centroid(face_points) - face_centroid.weight = settings.face_points_weight / weight_pref_total + face_centroid.weight = settings.face_points_weight / weight_pref_total pois.append(face_centroid) average_point = poi_average(pois, settings) @@ -132,7 +133,7 @@ def focal_point(im, settings): d.rectangle(f.bounding(4), outline=color) d.ellipse(average_point.bounding(max_size), outline=GREEN) - + return average_point @@ -260,10 +261,11 @@ def image_entropy(im): hist = hist[hist > 0] return -np.log2(hist / hist.sum()).sum() + def centroid(pois): - x = [poi.x for poi in pois] - y = [poi.y for poi in pois] - return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois)) + x = [poi.x for poi in pois] + y = [poi.y for poi in pois] + return PointOfInterest(sum(x) / len(pois), sum(y) / len(pois)) def poi_average(pois, settings): @@ -281,59 +283,59 @@ def poi_average(pois, settings): def is_landscape(w, h): - return w > h + return w > h def is_portrait(w, h): - return h > w + return h > w def is_square(w, h): - return w == h + return w == h def download_and_cache_models(dirname): - download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true' - model_file_name = 'face_detection_yunet.onnx' + download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true' + model_file_name = 'face_detection_yunet.onnx' - if not os.path.exists(dirname): - os.makedirs(dirname) + if not os.path.exists(dirname): + os.makedirs(dirname) - cache_file = os.path.join(dirname, model_file_name) - if not os.path.exists(cache_file): - print(f"downloading face detection model from '{download_url}' to '{cache_file}'") - response = requests.get(download_url) - with open(cache_file, "wb") as f: - f.write(response.content) + cache_file = os.path.join(dirname, model_file_name) + if not os.path.exists(cache_file): + print(f"downloading face detection model from '{download_url}' to '{cache_file}'") + response = requests.get(download_url) + with open(cache_file, "wb") as f: + f.write(response.content) - if os.path.exists(cache_file): - return cache_file - return None + if os.path.exists(cache_file): + return cache_file + return None class PointOfInterest: - def __init__(self, x, y, weight=1.0, size=10): - self.x = x - self.y = y - self.weight = weight - self.size = size + def __init__(self, x, y, weight=1.0, size=10): + self.x = x + self.y = y + self.weight = weight + self.size = size - def bounding(self, size): - return [ - self.x - size//2, - self.y - size//2, - self.x + size//2, - self.y + size//2 - ] + def bounding(self, size): + return [ + self.x - size // 2, + self.y - size // 2, + self.x + size // 2, + self.y + size // 2 + ] class Settings: - def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None): - self.crop_width = crop_width - self.crop_height = crop_height - self.corner_points_weight = corner_points_weight - self.entropy_points_weight = entropy_points_weight - self.face_points_weight = face_points_weight - self.annotate_image = annotate_image - self.destop_view_image = False - self.dnn_model_path = dnn_model_path + def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None): + self.crop_width = crop_width + self.crop_height = crop_height + self.corner_points_weight = corner_points_weight + self.entropy_points_weight = entropy_points_weight + self.face_points_weight = face_points_weight + self.annotate_image = annotate_image + self.destop_view_image = False + self.dnn_model_path = dnn_model_path -- cgit v1.2.1 From 49a55b410b66b7dd9be9335d8a2e3a71e4f8b15c Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 11 May 2023 18:28:15 +0300 Subject: Autofix Ruff W (not W605) (mostly whitespace) --- modules/textual_inversion/dataset.py | 4 ++-- modules/textual_inversion/preprocess.py | 2 +- modules/textual_inversion/textual_inversion.py | 16 ++++++++-------- 3 files changed, 11 insertions(+), 11 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 41610e03..b9621fc9 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -118,7 +118,7 @@ class PersonalizedBase(Dataset): weight = torch.ones(latent_sample.shape) else: weight = None - + if latent_sampling_method == "random": entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight) else: @@ -243,4 +243,4 @@ class BatchLoaderRandom(BatchLoader): return self def collate_wrapper_random(batch): - return BatchLoaderRandom(batch) \ No newline at end of file + return BatchLoaderRandom(batch) diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index d0cad09e..a009d8e8 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -125,7 +125,7 @@ def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, thr default=None ) return wh and center_crop(image, *wh) - + def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None): width = process_width diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 9e1b2b9a..d489ed1e 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -323,16 +323,16 @@ def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epo tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step) def tensorboard_add_scaler(tensorboard_writer, tag, value, step): - tensorboard_writer.add_scalar(tag=tag, + tensorboard_writer.add_scalar(tag=tag, scalar_value=value, global_step=step) def tensorboard_add_image(tensorboard_writer, tag, pil_image, step): # Convert a pil image to a torch tensor img_tensor = torch.as_tensor(np.array(pil_image, copy=True)) - img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], + img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], len(pil_image.getbands())) img_tensor = img_tensor.permute((2, 0, 1)) - + tensorboard_writer.add_image(tag, img_tensor, global_step=step) def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"): @@ -402,7 +402,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st if initial_step >= steps: shared.state.textinfo = "Model has already been trained beyond specified max steps" return embedding, filename - + scheduler = LearnRateScheduler(learn_rate, steps, initial_step) clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \ @@ -412,7 +412,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." old_parallel_processing_allowed = shared.parallel_processing_allowed - + if shared.opts.training_enable_tensorboard: tensorboard_writer = tensorboard_setup(log_directory) @@ -439,7 +439,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu') if embedding.checksum() == optimizer_saved_dict.get('hash', None): optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) - + if optimizer_state_dict is not None: optimizer.load_state_dict(optimizer_state_dict) print("Loaded existing optimizer from checkpoint") @@ -485,7 +485,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st if clip_grad: clip_grad_sched.step(embedding.step) - + with devices.autocast(): x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) if use_weight: @@ -513,7 +513,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st # go back until we reach gradient accumulation steps if (j + 1) % gradient_step != 0: continue - + if clip_grad: clip_grad(embedding.vec, clip_grad_sched.learn_rate) -- cgit v1.2.1