From abeec4b63029c2c4151a78fc395d312113881845 Mon Sep 17 00:00:00 2001 From: captin411 Date: Wed, 19 Oct 2022 03:18:26 -0700 Subject: Add auto focal point cropping to Preprocess images This algorithm plots a bunch of points of interest on the source image and averages their locations to find a center. Most points come from OpenCV. One point comes from an entropy model. OpenCV points account for 50% of the weight and the entropy based point is the other 50%. The center of all weighted points is calculated and a bounding box is drawn as close to centered over that point as possible. --- modules/textual_inversion/preprocess.py | 151 ++++++++++++++++++++++++++++++-- 1 file changed, 146 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 886cf0c3..168bfb09 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -1,5 +1,7 @@ import os -from PIL import Image, ImageOps +import cv2 +import numpy as np +from PIL import Image, ImageOps, ImageDraw import platform import sys import tqdm @@ -11,7 +13,7 @@ if cmd_opts.deepdanbooru: import modules.deepbooru as deepbooru -def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): +def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, process_entropy_focus=False): try: if process_caption: shared.interrogator.load() @@ -21,7 +23,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ db_opts[deepbooru.OPT_INCLUDE_RANKS] = False deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts) - preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru) + preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru, process_entropy_focus) finally: @@ -33,7 +35,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ -def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): +def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, process_entropy_focus=False): width = process_width height = process_height src = os.path.abspath(process_src) @@ -93,6 +95,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro is_tall = ratio > 1.35 is_wide = ratio < 1 / 1.35 + processing_option_ran = False + if process_split and is_tall: img = img.resize((width, height * img.height // img.width)) @@ -101,6 +105,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro bot = img.crop((0, img.height - height, width, img.height)) save_pic(bot, index) + + processing_option_ran = True elif process_split and is_wide: img = img.resize((width * img.width // img.height, height)) @@ -109,8 +115,143 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro right = img.crop((img.width - width, 0, img.width, height)) save_pic(right, index) - else: + + processing_option_ran = True + + if process_entropy_focus and (is_tall or is_wide): + 
if is_tall: + img = img.resize((width, height * img.height // img.width)) + else: + img = img.resize((width * img.width // img.height, height)) + + x_focal_center, y_focal_center = image_central_focal_point(img, width, height) + + # take the focal point and turn it into crop coordinates that try to center over the focal + # point but then get adjusted back into the frame + y_half = int(height / 2) + x_half = int(width / 2) + + x1 = x_focal_center - x_half + if x1 < 0: + x1 = 0 + elif x1 + width > img.width: + x1 = img.width - width + + y1 = y_focal_center - y_half + if y1 < 0: + y1 = 0 + elif y1 + height > img.height: + y1 = img.height - height + + x2 = x1 + width + y2 = y1 + height + + crop = [x1, y1, x2, y2] + + focal = img.crop(tuple(crop)) + save_pic(focal, index) + + processing_option_ran = True + + if not processing_option_ran: img = images.resize_image(1, img, width, height) save_pic(img, index) shared.state.nextjob() + + +def image_central_focal_point(im, target_width, target_height): + focal_points = [] + + focal_points.extend( + image_focal_points(im) + ) + + fp_entropy = image_entropy_point(im, target_width, target_height) + fp_entropy['weight'] = len(focal_points) + 1 # about half of the weight to entropy + + focal_points.append(fp_entropy) + + weight = 0.0 + x = 0.0 + y = 0.0 + for focal_point in focal_points: + weight += focal_point['weight'] + x += focal_point['x'] * focal_point['weight'] + y += focal_point['y'] * focal_point['weight'] + avg_x = round(x // weight) + avg_y = round(y // weight) + + return avg_x, avg_y + + +def image_focal_points(im): + grayscale = im.convert("L") + + # naive attempt at preventing focal points from collecting at watermarks near the bottom + gd = ImageDraw.Draw(grayscale) + gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") + + np_im = np.array(grayscale) + + points = cv2.goodFeaturesToTrack( + np_im, + maxCorners=50, + qualityLevel=0.04, + minDistance=min(grayscale.width, grayscale.height)*0.05, + useHarrisDetector=False, + ) + + if points is None: + return [] + + focal_points = [] + for point in points: + x, y = point.ravel() + focal_points.append({ + 'x': x, + 'y': y, + 'weight': 1.0 + }) + + return focal_points + + +def image_entropy_point(im, crop_width, crop_height): + img = im.copy() + # just make it easier to slide the test crop with images oriented the same way + if (img.size[0] < img.size[1]): + portrait = True + img = img.rotate(90, expand=1) + + e_max = 0 + crop_current = [0, 0, crop_width, crop_height] + crop_best = crop_current + while crop_current[2] < img.size[0]: + crop = img.crop(tuple(crop_current)) + e = image_entropy(crop) + + if (e_max < e): + e_max = e + crop_best = list(crop_current) + + crop_current[0] += 4 + crop_current[2] += 4 + + x_mid = int((crop_best[2] - crop_best[0])/2) + y_mid = int((crop_best[3] - crop_best[1])/2) + + return { + 'x': x_mid, + 'y': y_mid, + 'weight': 1.0 + } + + +def image_entropy(im): + # greyscale image entropy + band = np.asarray(im.convert("L")) + hist, _ = np.histogram(band, bins=range(0, 256)) + hist = hist[hist > 0] + return -np.log2(hist / hist.sum()).sum() + -- cgit v1.2.1 From 087609ee181a91a523647435ffffa6288a317e2f Mon Sep 17 00:00:00 2001 From: captin411 Date: Wed, 19 Oct 2022 03:19:35 -0700 Subject: UI changes for focal point image cropping --- modules/ui.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 1ff7eb4f..b6be713b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1234,6 +1234,7 
@@ def create_ui(wrap_gradio_gpu_call): with gr.Row(): process_flip = gr.Checkbox(label='Create flipped copies') process_split = gr.Checkbox(label='Split oversized images into two') + process_entropy_focus = gr.Checkbox(label='Create auto focal point crop') process_caption = gr.Checkbox(label='Use BLIP for caption') process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False) @@ -1318,7 +1319,8 @@ def create_ui(wrap_gradio_gpu_call): process_flip, process_split, process_caption, - process_caption_deepbooru + process_caption_deepbooru, + process_entropy_focus ], outputs=[ ti_output, -- cgit v1.2.1 From 41e3877be2c667316515c86037413763eb0ba4da Mon Sep 17 00:00:00 2001 From: captin411 Date: Wed, 19 Oct 2022 13:44:59 -0700 Subject: fix entropy point calculation --- modules/textual_inversion/preprocess.py | 34 ++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 168bfb09..7c1a594e 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -196,9 +196,9 @@ def image_focal_points(im): points = cv2.goodFeaturesToTrack( np_im, - maxCorners=50, + maxCorners=100, qualityLevel=0.04, - minDistance=min(grayscale.width, grayscale.height)*0.05, + minDistance=min(grayscale.width, grayscale.height)*0.07, useHarrisDetector=False, ) @@ -218,28 +218,32 @@ def image_focal_points(im): def image_entropy_point(im, crop_width, crop_height): - img = im.copy() - # just make it easier to slide the test crop with images oriented the same way - if (img.size[0] < img.size[1]): - portrait = True - img = img.rotate(90, expand=1) + landscape = im.height < im.width + portrait = im.height > im.width + if landscape: + move_idx = [0, 2] + move_max = im.size[0] + elif portrait: + move_idx = [1, 3] + move_max = im.size[1] e_max = 0 crop_current = [0, 0, crop_width, crop_height] crop_best = crop_current - while crop_current[2] < img.size[0]: - crop = img.crop(tuple(crop_current)) + while crop_current[move_idx[1]] < move_max: + crop = im.crop(tuple(crop_current)) e = image_entropy(crop) - if (e_max < e): + if (e > e_max): e_max = e crop_best = list(crop_current) - crop_current[0] += 4 - crop_current[2] += 4 + crop_current[move_idx[0]] += 4 + crop_current[move_idx[1]] += 4 + + x_mid = int(crop_best[0] + crop_width/2) + y_mid = int(crop_best[1] + crop_height/2) - x_mid = int((crop_best[2] - crop_best[0])/2) - y_mid = int((crop_best[3] - crop_best[1])/2) return { 'x': x_mid, @@ -250,7 +254,7 @@ def image_entropy_point(im, crop_width, crop_height): def image_entropy(im): # greyscale image entropy - band = np.asarray(im.convert("L")) + band = np.asarray(im.convert("1")) hist, _ = np.histogram(band, bins=range(0, 256)) hist = hist[hist > 0] return -np.log2(hist / hist.sum()).sum() -- cgit v1.2.1 From 59ed74438318af893d2cba552b0e28dbc2a9266c Mon Sep 17 00:00:00 2001 From: captin411 Date: Wed, 19 Oct 2022 17:19:02 -0700 Subject: face detection algo, configurability, reusability Try to move the crop in the direction of a face if it is present More internal configuration options for choosing weights of each of the algorithm's findings Move logic into its module --- modules/textual_inversion/autocrop.py | 216 ++++++++++++++++++++++++++++++++ modules/textual_inversion/preprocess.py | 150 +++------------------- 2 files changed, 230 insertions(+), 136 deletions(-) create mode 100644 
modules/textual_inversion/autocrop.py (limited to 'modules') diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py new file mode 100644 index 00000000..f858a958 --- /dev/null +++ b/modules/textual_inversion/autocrop.py @@ -0,0 +1,216 @@ +import cv2 +from collections import defaultdict +from math import log, sqrt +import numpy as np +from PIL import Image, ImageDraw + +GREEN = "#0F0" +BLUE = "#00F" +RED = "#F00" + +def crop_image(im, settings): + """ Intelligently crop an image to the subject matter """ + if im.height > im.width: + im = im.resize((settings.crop_width, settings.crop_height * im.height // im.width)) + else: + im = im.resize((settings.crop_width * im.width // im.height, settings.crop_height)) + + focus = focal_point(im, settings) + + # take the focal point and turn it into crop coordinates that try to center over the focal + # point but then get adjusted back into the frame + y_half = int(settings.crop_height / 2) + x_half = int(settings.crop_width / 2) + + x1 = focus.x - x_half + if x1 < 0: + x1 = 0 + elif x1 + settings.crop_width > im.width: + x1 = im.width - settings.crop_width + + y1 = focus.y - y_half + if y1 < 0: + y1 = 0 + elif y1 + settings.crop_height > im.height: + y1 = im.height - settings.crop_height + + x2 = x1 + settings.crop_width + y2 = y1 + settings.crop_height + + crop = [x1, y1, x2, y2] + + if settings.annotate_image: + d = ImageDraw.Draw(im) + rect = list(crop) + rect[2] -= 1 + rect[3] -= 1 + d.rectangle(rect, outline=GREEN) + if settings.destop_view_image: + im.show() + + return im.crop(tuple(crop)) + +def focal_point(im, settings): + corner_points = image_corner_points(im, settings) + entropy_points = image_entropy_points(im, settings) + face_points = image_face_points(im, settings) + + total_points = len(corner_points) + len(entropy_points) + len(face_points) + + corner_weight = settings.corner_points_weight + entropy_weight = settings.entropy_points_weight + face_weight = settings.face_points_weight + + weight_pref_total = corner_weight + entropy_weight + face_weight + + # weight things + pois = [] + if weight_pref_total == 0 or total_points == 0: + return pois + + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (corner_weight/weight_pref_total) / (len(corner_points)/total_points) )) for p in corner_points ] + ) + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (entropy_weight/weight_pref_total) / (len(entropy_points)/total_points) )) for p in entropy_points ] + ) + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ] + ) + + if settings.annotate_image: + d = ImageDraw.Draw(im) + + average_point = poi_average(pois, settings, im=im) + + if settings.annotate_image: + d.ellipse([average_point.x - 25, average_point.y - 25, average_point.x + 25, average_point.y + 25], outline=GREEN) + + return average_point + + +def image_face_points(im, settings): + np_im = np.array(im) + gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) + classifier = cv2.CascadeClassifier(f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml') + + minsize = int(min(im.width, im.height) * 0.15) # at least N percent of the smallest side + faces = classifier.detectMultiScale(gray, scaleFactor=1.05, + minNeighbors=5, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) + + if len(faces) == 0: + return [] + + rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] + if settings.annotate_image: + for f in 
rects: + d = ImageDraw.Draw(im) + d.rectangle(f, outline=RED) + + return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2) for r in rects] + + +def image_corner_points(im, settings): + grayscale = im.convert("L") + + # naive attempt at preventing focal points from collecting at watermarks near the bottom + gd = ImageDraw.Draw(grayscale) + gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") + + np_im = np.array(grayscale) + + points = cv2.goodFeaturesToTrack( + np_im, + maxCorners=100, + qualityLevel=0.04, + minDistance=min(grayscale.width, grayscale.height)*0.07, + useHarrisDetector=False, + ) + + if points is None: + return [] + + focal_points = [] + for point in points: + x, y = point.ravel() + focal_points.append(PointOfInterest(x, y)) + + return focal_points + + +def image_entropy_points(im, settings): + landscape = im.height < im.width + portrait = im.height > im.width + if landscape: + move_idx = [0, 2] + move_max = im.size[0] + elif portrait: + move_idx = [1, 3] + move_max = im.size[1] + else: + return [] + + e_max = 0 + crop_current = [0, 0, settings.crop_width, settings.crop_height] + crop_best = crop_current + while crop_current[move_idx[1]] < move_max: + crop = im.crop(tuple(crop_current)) + e = image_entropy(crop) + + if (e > e_max): + e_max = e + crop_best = list(crop_current) + + crop_current[move_idx[0]] += 4 + crop_current[move_idx[1]] += 4 + + x_mid = int(crop_best[0] + settings.crop_width/2) + y_mid = int(crop_best[1] + settings.crop_height/2) + + return [PointOfInterest(x_mid, y_mid)] + + +def image_entropy(im): + # greyscale image entropy + band = np.asarray(im.convert("1")) + hist, _ = np.histogram(band, bins=range(0, 256)) + hist = hist[hist > 0] + return -np.log2(hist / hist.sum()).sum() + + +def poi_average(pois, settings, im=None): + weight = 0.0 + x = 0.0 + y = 0.0 + for pois in pois: + if settings.annotate_image and im is not None: + w = 4 * 0.5 * sqrt(pois.weight) + d = ImageDraw.Draw(im) + d.ellipse([ + pois.x - w, pois.y - w, + pois.x + w, pois.y + w ], fill=BLUE) + weight += pois.weight + x += pois.x * pois.weight + y += pois.y * pois.weight + avg_x = round(x / weight) + avg_y = round(y / weight) + + return PointOfInterest(avg_x, avg_y) + + +class PointOfInterest: + def __init__(self, x, y, weight=1.0): + self.x = x + self.y = y + self.weight = weight + + +class Settings: + def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False): + self.crop_width = crop_width + self.crop_height = crop_height + self.corner_points_weight = corner_points_weight + self.entropy_points_weight = entropy_points_weight + self.face_points_weight = entropy_points_weight + self.annotate_image = annotate_image + self.destop_view_image = False \ No newline at end of file diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 7c1a594e..0c79f012 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -1,7 +1,5 @@ import os -import cv2 -import numpy as np -from PIL import Image, ImageOps, ImageDraw +from PIL import Image, ImageOps import platform import sys import tqdm @@ -9,6 +7,7 @@ import time from modules import shared, images from modules.shared import opts, cmd_opts +from modules.textual_inversion import autocrop if cmd_opts.deepdanbooru: import modules.deepbooru as deepbooru @@ -80,6 +79,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro if 
process_flip: save_pic_with_caption(ImageOps.mirror(image), index) + for index, imagefile in enumerate(tqdm.tqdm(files)): subindex = [0] filename = os.path.join(src, imagefile) @@ -118,37 +118,16 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro processing_option_ran = True - if process_entropy_focus and (is_tall or is_wide): - if is_tall: - img = img.resize((width, height * img.height // img.width)) - else: - img = img.resize((width * img.width // img.height, height)) - - x_focal_center, y_focal_center = image_central_focal_point(img, width, height) - - # take the focal point and turn it into crop coordinates that try to center over the focal - # point but then get adjusted back into the frame - y_half = int(height / 2) - x_half = int(width / 2) - - x1 = x_focal_center - x_half - if x1 < 0: - x1 = 0 - elif x1 + width > img.width: - x1 = img.width - width - - y1 = y_focal_center - y_half - if y1 < 0: - y1 = 0 - elif y1 + height > img.height: - y1 = img.height - height - - x2 = x1 + width - y2 = y1 + height - - crop = [x1, y1, x2, y2] - - focal = img.crop(tuple(crop)) + if process_entropy_focus and img.height != img.width: + autocrop_settings = autocrop.Settings( + crop_width = width, + crop_height = height, + face_points_weight = 0.9, + entropy_points_weight = 0.7, + corner_points_weight = 0.5, + annotate_image = False + ) + focal = autocrop.crop_image(img, autocrop_settings) save_pic(focal, index) processing_option_ran = True @@ -157,105 +136,4 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro img = images.resize_image(1, img, width, height) save_pic(img, index) - shared.state.nextjob() - - -def image_central_focal_point(im, target_width, target_height): - focal_points = [] - - focal_points.extend( - image_focal_points(im) - ) - - fp_entropy = image_entropy_point(im, target_width, target_height) - fp_entropy['weight'] = len(focal_points) + 1 # about half of the weight to entropy - - focal_points.append(fp_entropy) - - weight = 0.0 - x = 0.0 - y = 0.0 - for focal_point in focal_points: - weight += focal_point['weight'] - x += focal_point['x'] * focal_point['weight'] - y += focal_point['y'] * focal_point['weight'] - avg_x = round(x // weight) - avg_y = round(y // weight) - - return avg_x, avg_y - - -def image_focal_points(im): - grayscale = im.convert("L") - - # naive attempt at preventing focal points from collecting at watermarks near the bottom - gd = ImageDraw.Draw(grayscale) - gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") - - np_im = np.array(grayscale) - - points = cv2.goodFeaturesToTrack( - np_im, - maxCorners=100, - qualityLevel=0.04, - minDistance=min(grayscale.width, grayscale.height)*0.07, - useHarrisDetector=False, - ) - - if points is None: - return [] - - focal_points = [] - for point in points: - x, y = point.ravel() - focal_points.append({ - 'x': x, - 'y': y, - 'weight': 1.0 - }) - - return focal_points - - -def image_entropy_point(im, crop_width, crop_height): - landscape = im.height < im.width - portrait = im.height > im.width - if landscape: - move_idx = [0, 2] - move_max = im.size[0] - elif portrait: - move_idx = [1, 3] - move_max = im.size[1] - - e_max = 0 - crop_current = [0, 0, crop_width, crop_height] - crop_best = crop_current - while crop_current[move_idx[1]] < move_max: - crop = im.crop(tuple(crop_current)) - e = image_entropy(crop) - - if (e > e_max): - e_max = e - crop_best = list(crop_current) - - crop_current[move_idx[0]] += 4 - crop_current[move_idx[1]] += 4 - - x_mid = 
int(crop_best[0] + crop_width/2) - y_mid = int(crop_best[1] + crop_height/2) - - - return { - 'x': x_mid, - 'y': y_mid, - 'weight': 1.0 - } - - -def image_entropy(im): - # greyscale image entropy - band = np.asarray(im.convert("1")) - hist, _ = np.histogram(band, bins=range(0, 256)) - hist = hist[hist > 0] - return -np.log2(hist / hist.sum()).sum() - + shared.state.nextjob() \ No newline at end of file -- cgit v1.2.1 From 0ddaf8d2028a7251e8c4ad93551a43b5d4700841 Mon Sep 17 00:00:00 2001 From: captin411 Date: Thu, 20 Oct 2022 00:34:55 -0700 Subject: improve face detection a lot --- modules/textual_inversion/autocrop.py | 99 ++++++++++++++++++++++------------- 1 file changed, 62 insertions(+), 37 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index f858a958..5a551c25 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -8,12 +8,18 @@ GREEN = "#0F0" BLUE = "#00F" RED = "#F00" + def crop_image(im, settings): """ Intelligently crop an image to the subject matter """ if im.height > im.width: im = im.resize((settings.crop_width, settings.crop_height * im.height // im.width)) - else: + elif im.width > im.height: im = im.resize((settings.crop_width * im.width // im.height, settings.crop_height)) + else: + im = im.resize((settings.crop_width, settings.crop_height)) + + if im.height == im.width: + return im focus = focal_point(im, settings) @@ -78,13 +84,18 @@ def focal_point(im, settings): [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ] ) - if settings.annotate_image: - d = ImageDraw.Draw(im) - - average_point = poi_average(pois, settings, im=im) + average_point = poi_average(pois, settings) if settings.annotate_image: - d.ellipse([average_point.x - 25, average_point.y - 25, average_point.x + 25, average_point.y + 25], outline=GREEN) + d = ImageDraw.Draw(im) + for f in face_points: + d.rectangle(f.bounding(f.size), outline=RED) + for f in entropy_points: + d.rectangle(f.bounding(30), outline=BLUE) + for poi in pois: + w = max(4, 4 * 0.5 * sqrt(poi.weight)) + d.ellipse(poi.bounding(w), fill=BLUE) + d.ellipse(average_point.bounding(25), outline=GREEN) return average_point @@ -92,22 +103,32 @@ def focal_point(im, settings): def image_face_points(im, settings): np_im = np.array(im) gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) - classifier = cv2.CascadeClassifier(f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml') - - minsize = int(min(im.width, im.height) * 0.15) # at least N percent of the smallest side - faces = classifier.detectMultiScale(gray, scaleFactor=1.05, - minNeighbors=5, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) - if len(faces) == 0: - return [] - - rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] - if settings.annotate_image: - for f in rects: - d = ImageDraw.Draw(im) - d.rectangle(f, outline=RED) - - return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2) for r in rects] + tries = [ + [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ], + [ 
f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ] + ] + + for t in tries: + # print(t[0]) + classifier = cv2.CascadeClassifier(t[0]) + minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side + try: + faces = classifier.detectMultiScale(gray, scaleFactor=1.1, + minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) + except: + continue + + if len(faces) > 0: + rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] + return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2])) for r in rects] + return [] def image_corner_points(im, settings): @@ -132,8 +153,8 @@ def image_corner_points(im, settings): focal_points = [] for point in points: - x, y = point.ravel() - focal_points.append(PointOfInterest(x, y)) + x, y = point.ravel() + focal_points.append(PointOfInterest(x, y, size=4)) return focal_points @@ -167,31 +188,26 @@ def image_entropy_points(im, settings): x_mid = int(crop_best[0] + settings.crop_width/2) y_mid = int(crop_best[1] + settings.crop_height/2) - return [PointOfInterest(x_mid, y_mid)] + return [PointOfInterest(x_mid, y_mid, size=25)] def image_entropy(im): # greyscale image entropy - band = np.asarray(im.convert("1")) + # band = np.asarray(im.convert("L")) + band = np.asarray(im.convert("1"), dtype=np.uint8) hist, _ = np.histogram(band, bins=range(0, 256)) hist = hist[hist > 0] return -np.log2(hist / hist.sum()).sum() -def poi_average(pois, settings, im=None): +def poi_average(pois, settings): weight = 0.0 x = 0.0 y = 0.0 - for pois in pois: - if settings.annotate_image and im is not None: - w = 4 * 0.5 * sqrt(pois.weight) - d = ImageDraw.Draw(im) - d.ellipse([ - pois.x - w, pois.y - w, - pois.x + w, pois.y + w ], fill=BLUE) - weight += pois.weight - x += pois.x * pois.weight - y += pois.y * pois.weight + for poi in pois: + weight += poi.weight + x += poi.x * poi.weight + y += poi.y * poi.weight avg_x = round(x / weight) avg_y = round(y / weight) @@ -199,10 +215,19 @@ def poi_average(pois, settings, im=None): class PointOfInterest: - def __init__(self, x, y, weight=1.0): + def __init__(self, x, y, weight=1.0, size=10): self.x = x self.y = y self.weight = weight + self.size = size + + def bounding(self, size): + return [ + self.x - size//2, + self.y - size//2, + self.x + size//2, + self.y + size//2 + ] class Settings: -- cgit v1.2.1 From 1be5933ba21a3badec42b7b2753d626f849b609d Mon Sep 17 00:00:00 2001 From: captin411 Date: Sun, 23 Oct 2022 04:11:07 -0700 Subject: auto cropping now works with non square crops --- modules/textual_inversion/autocrop.py | 509 ++++++++++++++++++---------------- 1 file changed, 269 insertions(+), 240 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index 5a551c25..b2f9241c 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -1,241 +1,270 @@ -import cv2 -from collections import defaultdict -from math import log, sqrt -import numpy as np -from PIL import Image, ImageDraw - -GREEN = "#0F0" -BLUE = "#00F" -RED = "#F00" - - -def crop_image(im, settings): - """ Intelligently crop an image to the subject matter """ - if im.height > im.width: - im = im.resize((settings.crop_width, settings.crop_height * im.height // im.width)) - elif im.width > im.height: - im = im.resize((settings.crop_width * im.width // im.height, settings.crop_height)) - else: - im = 
im.resize((settings.crop_width, settings.crop_height)) - - if im.height == im.width: - return im - - focus = focal_point(im, settings) - - # take the focal point and turn it into crop coordinates that try to center over the focal - # point but then get adjusted back into the frame - y_half = int(settings.crop_height / 2) - x_half = int(settings.crop_width / 2) - - x1 = focus.x - x_half - if x1 < 0: - x1 = 0 - elif x1 + settings.crop_width > im.width: - x1 = im.width - settings.crop_width - - y1 = focus.y - y_half - if y1 < 0: - y1 = 0 - elif y1 + settings.crop_height > im.height: - y1 = im.height - settings.crop_height - - x2 = x1 + settings.crop_width - y2 = y1 + settings.crop_height - - crop = [x1, y1, x2, y2] - - if settings.annotate_image: - d = ImageDraw.Draw(im) - rect = list(crop) - rect[2] -= 1 - rect[3] -= 1 - d.rectangle(rect, outline=GREEN) - if settings.destop_view_image: - im.show() - - return im.crop(tuple(crop)) - -def focal_point(im, settings): - corner_points = image_corner_points(im, settings) - entropy_points = image_entropy_points(im, settings) - face_points = image_face_points(im, settings) - - total_points = len(corner_points) + len(entropy_points) + len(face_points) - - corner_weight = settings.corner_points_weight - entropy_weight = settings.entropy_points_weight - face_weight = settings.face_points_weight - - weight_pref_total = corner_weight + entropy_weight + face_weight - - # weight things - pois = [] - if weight_pref_total == 0 or total_points == 0: - return pois - - pois.extend( - [ PointOfInterest( p.x, p.y, weight=p.weight * ( (corner_weight/weight_pref_total) / (len(corner_points)/total_points) )) for p in corner_points ] - ) - pois.extend( - [ PointOfInterest( p.x, p.y, weight=p.weight * ( (entropy_weight/weight_pref_total) / (len(entropy_points)/total_points) )) for p in entropy_points ] - ) - pois.extend( - [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ] - ) - - average_point = poi_average(pois, settings) - - if settings.annotate_image: - d = ImageDraw.Draw(im) - for f in face_points: - d.rectangle(f.bounding(f.size), outline=RED) - for f in entropy_points: - d.rectangle(f.bounding(30), outline=BLUE) - for poi in pois: - w = max(4, 4 * 0.5 * sqrt(poi.weight)) - d.ellipse(poi.bounding(w), fill=BLUE) - d.ellipse(average_point.bounding(25), outline=GREEN) - - return average_point - - -def image_face_points(im, settings): - np_im = np.array(im) - gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) - - tries = [ - [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ] - ] - - for t in tries: - # print(t[0]) - classifier = cv2.CascadeClassifier(t[0]) - minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side - try: - faces = classifier.detectMultiScale(gray, scaleFactor=1.1, - minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) - except: - continue - - if len(faces) > 0: - rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f 
in faces] - return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2])) for r in rects] - return [] - - -def image_corner_points(im, settings): - grayscale = im.convert("L") - - # naive attempt at preventing focal points from collecting at watermarks near the bottom - gd = ImageDraw.Draw(grayscale) - gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") - - np_im = np.array(grayscale) - - points = cv2.goodFeaturesToTrack( - np_im, - maxCorners=100, - qualityLevel=0.04, - minDistance=min(grayscale.width, grayscale.height)*0.07, - useHarrisDetector=False, - ) - - if points is None: - return [] - - focal_points = [] - for point in points: - x, y = point.ravel() - focal_points.append(PointOfInterest(x, y, size=4)) - - return focal_points - - -def image_entropy_points(im, settings): - landscape = im.height < im.width - portrait = im.height > im.width - if landscape: - move_idx = [0, 2] - move_max = im.size[0] - elif portrait: - move_idx = [1, 3] - move_max = im.size[1] - else: - return [] - - e_max = 0 - crop_current = [0, 0, settings.crop_width, settings.crop_height] - crop_best = crop_current - while crop_current[move_idx[1]] < move_max: - crop = im.crop(tuple(crop_current)) - e = image_entropy(crop) - - if (e > e_max): - e_max = e - crop_best = list(crop_current) - - crop_current[move_idx[0]] += 4 - crop_current[move_idx[1]] += 4 - - x_mid = int(crop_best[0] + settings.crop_width/2) - y_mid = int(crop_best[1] + settings.crop_height/2) - - return [PointOfInterest(x_mid, y_mid, size=25)] - - -def image_entropy(im): - # greyscale image entropy - # band = np.asarray(im.convert("L")) - band = np.asarray(im.convert("1"), dtype=np.uint8) - hist, _ = np.histogram(band, bins=range(0, 256)) - hist = hist[hist > 0] - return -np.log2(hist / hist.sum()).sum() - - -def poi_average(pois, settings): - weight = 0.0 - x = 0.0 - y = 0.0 - for poi in pois: - weight += poi.weight - x += poi.x * poi.weight - y += poi.y * poi.weight - avg_x = round(x / weight) - avg_y = round(y / weight) - - return PointOfInterest(avg_x, avg_y) - - -class PointOfInterest: - def __init__(self, x, y, weight=1.0, size=10): - self.x = x - self.y = y - self.weight = weight - self.size = size - - def bounding(self, size): - return [ - self.x - size//2, - self.y - size//2, - self.x + size//2, - self.y + size//2 - ] - - -class Settings: - def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False): - self.crop_width = crop_width - self.crop_height = crop_height - self.corner_points_weight = corner_points_weight - self.entropy_points_weight = entropy_points_weight - self.face_points_weight = entropy_points_weight - self.annotate_image = annotate_image +import cv2 +from collections import defaultdict +from math import log, sqrt +import numpy as np +from PIL import Image, ImageDraw + +GREEN = "#0F0" +BLUE = "#00F" +RED = "#F00" + + +def crop_image(im, settings): + """ Intelligently crop an image to the subject matter """ + + scale_by = 1 + if is_landscape(im.width, im.height): + scale_by = settings.crop_height / im.height + elif is_portrait(im.width, im.height): + scale_by = settings.crop_width / im.width + elif is_square(im.width, im.height): + if is_square(settings.crop_width, settings.crop_height): + scale_by = settings.crop_width / im.width + elif is_landscape(settings.crop_width, settings.crop_height): + scale_by = settings.crop_width / im.width + elif is_portrait(settings.crop_width, settings.crop_height): + 
scale_by = settings.crop_height / im.height + + im = im.resize((int(im.width * scale_by), int(im.height * scale_by))) + + if im.width == settings.crop_width and im.height == settings.crop_height: + if settings.annotate_image: + d = ImageDraw.Draw(im) + rect = [0, 0, im.width, im.height] + rect[2] -= 1 + rect[3] -= 1 + d.rectangle(rect, outline=GREEN) + if settings.destop_view_image: + im.show() + return im + + focus = focal_point(im, settings) + + # take the focal point and turn it into crop coordinates that try to center over the focal + # point but then get adjusted back into the frame + y_half = int(settings.crop_height / 2) + x_half = int(settings.crop_width / 2) + + x1 = focus.x - x_half + if x1 < 0: + x1 = 0 + elif x1 + settings.crop_width > im.width: + x1 = im.width - settings.crop_width + + y1 = focus.y - y_half + if y1 < 0: + y1 = 0 + elif y1 + settings.crop_height > im.height: + y1 = im.height - settings.crop_height + + x2 = x1 + settings.crop_width + y2 = y1 + settings.crop_height + + crop = [x1, y1, x2, y2] + + if settings.annotate_image: + d = ImageDraw.Draw(im) + rect = list(crop) + rect[2] -= 1 + rect[3] -= 1 + d.rectangle(rect, outline=GREEN) + if settings.destop_view_image: + im.show() + + return im.crop(tuple(crop)) + +def focal_point(im, settings): + corner_points = image_corner_points(im, settings) + entropy_points = image_entropy_points(im, settings) + face_points = image_face_points(im, settings) + + total_points = len(corner_points) + len(entropy_points) + len(face_points) + + corner_weight = settings.corner_points_weight + entropy_weight = settings.entropy_points_weight + face_weight = settings.face_points_weight + + weight_pref_total = corner_weight + entropy_weight + face_weight + + # weight things + pois = [] + if weight_pref_total == 0 or total_points == 0: + return pois + + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (corner_weight/weight_pref_total) / (len(corner_points)/total_points) )) for p in corner_points ] + ) + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (entropy_weight/weight_pref_total) / (len(entropy_points)/total_points) )) for p in entropy_points ] + ) + pois.extend( + [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ] + ) + + average_point = poi_average(pois, settings) + + if settings.annotate_image: + d = ImageDraw.Draw(im) + for f in face_points: + d.rectangle(f.bounding(f.size), outline=RED) + for f in entropy_points: + d.rectangle(f.bounding(30), outline=BLUE) + for poi in pois: + w = max(4, 4 * 0.5 * sqrt(poi.weight)) + d.ellipse(poi.bounding(w), fill=BLUE) + d.ellipse(average_point.bounding(25), outline=GREEN) + + return average_point + + +def image_face_points(im, settings): + np_im = np.array(im) + gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) + + tries = [ + [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ] + ] + + for t in tries: + # print(t[0]) + classifier = cv2.CascadeClassifier(t[0]) + minsize = 
int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side + try: + faces = classifier.detectMultiScale(gray, scaleFactor=1.1, + minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) + except: + continue + + if len(faces) > 0: + rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] + return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2])) for r in rects] + return [] + + +def image_corner_points(im, settings): + grayscale = im.convert("L") + + # naive attempt at preventing focal points from collecting at watermarks near the bottom + gd = ImageDraw.Draw(grayscale) + gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") + + np_im = np.array(grayscale) + + points = cv2.goodFeaturesToTrack( + np_im, + maxCorners=100, + qualityLevel=0.04, + minDistance=min(grayscale.width, grayscale.height)*0.07, + useHarrisDetector=False, + ) + + if points is None: + return [] + + focal_points = [] + for point in points: + x, y = point.ravel() + focal_points.append(PointOfInterest(x, y, size=4)) + + return focal_points + + +def image_entropy_points(im, settings): + landscape = im.height < im.width + portrait = im.height > im.width + if landscape: + move_idx = [0, 2] + move_max = im.size[0] + elif portrait: + move_idx = [1, 3] + move_max = im.size[1] + else: + return [] + + e_max = 0 + crop_current = [0, 0, settings.crop_width, settings.crop_height] + crop_best = crop_current + while crop_current[move_idx[1]] < move_max: + crop = im.crop(tuple(crop_current)) + e = image_entropy(crop) + + if (e > e_max): + e_max = e + crop_best = list(crop_current) + + crop_current[move_idx[0]] += 4 + crop_current[move_idx[1]] += 4 + + x_mid = int(crop_best[0] + settings.crop_width/2) + y_mid = int(crop_best[1] + settings.crop_height/2) + + return [PointOfInterest(x_mid, y_mid, size=25)] + + +def image_entropy(im): + # greyscale image entropy + # band = np.asarray(im.convert("L")) + band = np.asarray(im.convert("1"), dtype=np.uint8) + hist, _ = np.histogram(band, bins=range(0, 256)) + hist = hist[hist > 0] + return -np.log2(hist / hist.sum()).sum() + + +def poi_average(pois, settings): + weight = 0.0 + x = 0.0 + y = 0.0 + for poi in pois: + weight += poi.weight + x += poi.x * poi.weight + y += poi.y * poi.weight + avg_x = round(x / weight) + avg_y = round(y / weight) + + return PointOfInterest(avg_x, avg_y) + + +def is_landscape(w, h): + return w > h + + +def is_portrait(w, h): + return h > w + + +def is_square(w, h): + return w == h + + +class PointOfInterest: + def __init__(self, x, y, weight=1.0, size=10): + self.x = x + self.y = y + self.weight = weight + self.size = size + + def bounding(self, size): + return [ + self.x - size//2, + self.y - size//2, + self.x + size//2, + self.y + size//2 + ] + + +class Settings: + def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False): + self.crop_width = crop_width + self.crop_height = crop_height + self.corner_points_weight = corner_points_weight + self.entropy_points_weight = entropy_points_weight + self.face_points_weight = entropy_points_weight + self.annotate_image = annotate_image self.destop_view_image = False \ No newline at end of file -- cgit v1.2.1 From 3e6c2420c1177e9e79f2b566a5a7795b7416e34a Mon Sep 17 00:00:00 2001 From: captin411 Date: Tue, 25 Oct 2022 13:10:58 -0700 Subject: improve debug markers, fix algo weighting --- modules/textual_inversion/autocrop.py | 207 
+++++++++++++++++++++------------- 1 file changed, 129 insertions(+), 78 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index b2f9241c..caaf18c8 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -1,4 +1,5 @@ import cv2 +import os from collections import defaultdict from math import log, sqrt import numpy as np @@ -26,19 +27,9 @@ def crop_image(im, settings): scale_by = settings.crop_height / im.height im = im.resize((int(im.width * scale_by), int(im.height * scale_by))) + im_debug = im.copy() - if im.width == settings.crop_width and im.height == settings.crop_height: - if settings.annotate_image: - d = ImageDraw.Draw(im) - rect = [0, 0, im.width, im.height] - rect[2] -= 1 - rect[3] -= 1 - d.rectangle(rect, outline=GREEN) - if settings.destop_view_image: - im.show() - return im - - focus = focal_point(im, settings) + focus = focal_point(im_debug, settings) # take the focal point and turn it into crop coordinates that try to center over the focal # point but then get adjusted back into the frame @@ -62,89 +53,143 @@ def crop_image(im, settings): crop = [x1, y1, x2, y2] + results = [] + + results.append(im.crop(tuple(crop))) + if settings.annotate_image: - d = ImageDraw.Draw(im) + d = ImageDraw.Draw(im_debug) rect = list(crop) rect[2] -= 1 rect[3] -= 1 d.rectangle(rect, outline=GREEN) + results.append(im_debug) if settings.destop_view_image: - im.show() + im_debug.show() - return im.crop(tuple(crop)) + return results def focal_point(im, settings): corner_points = image_corner_points(im, settings) entropy_points = image_entropy_points(im, settings) face_points = image_face_points(im, settings) - total_points = len(corner_points) + len(entropy_points) + len(face_points) - - corner_weight = settings.corner_points_weight - entropy_weight = settings.entropy_points_weight - face_weight = settings.face_points_weight - - weight_pref_total = corner_weight + entropy_weight + face_weight - - # weight things pois = [] - if weight_pref_total == 0 or total_points == 0: - return pois - pois.extend( - [ PointOfInterest( p.x, p.y, weight=p.weight * ( (corner_weight/weight_pref_total) / (len(corner_points)/total_points) )) for p in corner_points ] - ) - pois.extend( - [ PointOfInterest( p.x, p.y, weight=p.weight * ( (entropy_weight/weight_pref_total) / (len(entropy_points)/total_points) )) for p in entropy_points ] - ) - pois.extend( - [ PointOfInterest( p.x, p.y, weight=p.weight * ( (face_weight/weight_pref_total) / (len(face_points)/total_points) )) for p in face_points ] - ) + weight_pref_total = 0 + if len(corner_points) > 0: + weight_pref_total += settings.corner_points_weight + if len(entropy_points) > 0: + weight_pref_total += settings.entropy_points_weight + if len(face_points) > 0: + weight_pref_total += settings.face_points_weight + + corner_centroid = None + if len(corner_points) > 0: + corner_centroid = centroid(corner_points) + corner_centroid.weight = settings.corner_points_weight / weight_pref_total + pois.append(corner_centroid) + + entropy_centroid = None + if len(entropy_points) > 0: + entropy_centroid = centroid(entropy_points) + entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total + pois.append(entropy_centroid) + + face_centroid = None + if len(face_points) > 0: + face_centroid = centroid(face_points) + face_centroid.weight = settings.face_points_weight / weight_pref_total + pois.append(face_centroid) average_point = poi_average(pois, 
settings) if settings.annotate_image: d = ImageDraw.Draw(im) - for f in face_points: - d.rectangle(f.bounding(f.size), outline=RED) - for f in entropy_points: - d.rectangle(f.bounding(30), outline=BLUE) - for poi in pois: - w = max(4, 4 * 0.5 * sqrt(poi.weight)) - d.ellipse(poi.bounding(w), fill=BLUE) - d.ellipse(average_point.bounding(25), outline=GREEN) + max_size = min(im.width, im.height) * 0.07 + if corner_centroid is not None: + color = BLUE + box = corner_centroid.bounding(max_size * corner_centroid.weight) + d.text((box[0], box[1]-15), "Edge: %.02f" % corner_centroid.weight, fill=color) + d.ellipse(box, outline=color) + if len(corner_points) > 1: + for f in corner_points: + d.rectangle(f.bounding(4), outline=color) + if entropy_centroid is not None: + color = "#ff0" + box = entropy_centroid.bounding(max_size * entropy_centroid.weight) + d.text((box[0], box[1]-15), "Entropy: %.02f" % entropy_centroid.weight, fill=color) + d.ellipse(box, outline=color) + if len(entropy_points) > 1: + for f in entropy_points: + d.rectangle(f.bounding(4), outline=color) + if face_centroid is not None: + color = RED + box = face_centroid.bounding(max_size * face_centroid.weight) + d.text((box[0], box[1]-15), "Face: %.02f" % face_centroid.weight, fill=color) + d.ellipse(box, outline=color) + if len(face_points) > 1: + for f in face_points: + d.rectangle(f.bounding(4), outline=color) + + d.ellipse(average_point.bounding(max_size), outline=GREEN) return average_point def image_face_points(im, settings): - np_im = np.array(im) - gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) - - tries = [ - [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ], - [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ] - ] - - for t in tries: - # print(t[0]) - classifier = cv2.CascadeClassifier(t[0]) - minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side - try: - faces = classifier.detectMultiScale(gray, scaleFactor=1.1, - minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) - except: - continue - - if len(faces) > 0: - rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] - return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2])) for r in rects] + if settings.dnn_model_path is not None: + detector = cv2.FaceDetectorYN.create( + settings.dnn_model_path, + "", + (im.width, im.height), + 0.8, # score threshold + 0.3, # nms threshold + 5000 # keep top k before nms + ) + faces = detector.detect(np.array(im)) + results = [] + if faces[1] is not None: + for face in faces[1]: + x = face[0] + y = face[1] + w = face[2] + h = face[3] + results.append( + PointOfInterest( + int(x + (w * 0.5)), # face focus left/right is center + int(y + (h * 0)), # face focus up/down is close to the top of the head + size = w, + weight = 1/len(faces[1]) + ) + ) + return results + else: + np_im = np.array(im) + gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) + + tries = [ + [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ], + [ 
f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ], + [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ] + ] + for t in tries: + classifier = cv2.CascadeClassifier(t[0]) + minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side + try: + faces = classifier.detectMultiScale(gray, scaleFactor=1.1, + minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) + except: + continue + + if len(faces) > 0: + rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] + return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects] return [] @@ -161,7 +206,7 @@ def image_corner_points(im, settings): np_im, maxCorners=100, qualityLevel=0.04, - minDistance=min(grayscale.width, grayscale.height)*0.07, + minDistance=min(grayscale.width, grayscale.height)*0.03, useHarrisDetector=False, ) @@ -171,7 +216,7 @@ def image_corner_points(im, settings): focal_points = [] for point in points: x, y = point.ravel() - focal_points.append(PointOfInterest(x, y, size=4)) + focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points))) return focal_points @@ -205,17 +250,22 @@ def image_entropy_points(im, settings): x_mid = int(crop_best[0] + settings.crop_width/2) y_mid = int(crop_best[1] + settings.crop_height/2) - return [PointOfInterest(x_mid, y_mid, size=25)] + return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)] def image_entropy(im): # greyscale image entropy - # band = np.asarray(im.convert("L")) - band = np.asarray(im.convert("1"), dtype=np.uint8) + band = np.asarray(im.convert("L")) + # band = np.asarray(im.convert("1"), dtype=np.uint8) hist, _ = np.histogram(band, bins=range(0, 256)) hist = hist[hist > 0] return -np.log2(hist / hist.sum()).sum() +def centroid(pois): + x = [poi.x for poi in pois] + y = [poi.y for poi in pois] + return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois)) + def poi_average(pois, settings): weight = 0.0 @@ -260,11 +310,12 @@ class PointOfInterest: class Settings: - def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False): + def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None): self.crop_width = crop_width self.crop_height = crop_height self.corner_points_weight = corner_points_weight self.entropy_points_weight = entropy_points_weight - self.face_points_weight = entropy_points_weight + self.face_points_weight = face_points_weight self.annotate_image = annotate_image - self.destop_view_image = False \ No newline at end of file + self.destop_view_image = False + self.dnn_model_path = dnn_model_path \ No newline at end of file -- cgit v1.2.1 From db8ed5fe5cd6e967d12d43d96b7f83083e58626c Mon Sep 17 00:00:00 2001 From: captin411 Date: Tue, 25 Oct 2022 15:22:29 -0700 Subject: Focal crop UI elements --- modules/textual_inversion/preprocess.py | 26 +++++++++++++------------- modules/ui.py | 20 ++++++++++++++++++-- 2 files changed, 31 insertions(+), 15 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py 
b/modules/textual_inversion/preprocess.py index a8c17c6f..1e4d4de8 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -13,7 +13,7 @@ if cmd_opts.deepdanbooru: import modules.deepbooru as deepbooru -def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_entropy_focus=False): +def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False): try: if process_caption: shared.interrogator.load() @@ -23,7 +23,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce db_opts[deepbooru.OPT_INCLUDE_RANKS] = False deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts) - preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_entropy_focus) + preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug) finally: @@ -35,7 +35,7 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce -def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_entropy_focus=False): +def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False): width = process_width height = process_height src = os.path.abspath(process_src) @@ -139,27 +139,27 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre ratio = (img.height * width) / (img.width * height) inverse_xy = True - processing_option_ran = False + process_default_resize = True if process_split and ratio < 1.0 and ratio <= split_threshold: for splitted in split_pic(img, inverse_xy): save_pic(splitted, index, existing_caption=existing_caption) - processing_option_ran = True + process_default_resize = False if process_entropy_focus and img.height != img.width: autocrop_settings = autocrop.Settings( crop_width = width, crop_height = height, - face_points_weight = 0.9, - entropy_points_weight = 0.7, - corner_points_weight = 0.5, - annotate_image = False + face_points_weight = process_focal_crop_face_weight, + entropy_points_weight = process_focal_crop_entropy_weight, + corner_points_weight = process_focal_crop_edges_weight, + annotate_image = process_focal_crop_debug ) - focal = autocrop.crop_image(img, 
autocrop_settings) - save_pic(focal, index, existing_caption=existing_caption) - processing_option_ran = True + for focal in autocrop.crop_image(img, autocrop_settings): + save_pic(focal, index, existing_caption=existing_caption) + process_default_resize = False - if not processing_option_ran: + if process_default_resize: img = images.resize_image(1, img, width, height) save_pic(img, index, existing_caption=existing_caption) diff --git a/modules/ui.py b/modules/ui.py index 028eb4e5..95b9c703 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1260,7 +1260,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Row(): process_flip = gr.Checkbox(label='Create flipped copies') process_split = gr.Checkbox(label='Split oversized images') - process_entropy_focus = gr.Checkbox(label='Create auto focal point crop') + process_focal_crop = gr.Checkbox(label='Auto focal point crop') process_caption = gr.Checkbox(label='Use BLIP for caption') process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False) @@ -1268,6 +1268,12 @@ def create_ui(wrap_gradio_gpu_call): process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05) + with gr.Row(visible=False) as process_focal_crop_row: + process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05) + process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.3, minimum=0.0, maximum=1.0, step=0.05) + process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05) + process_focal_crop_debug = gr.Checkbox(label='Create debug image') + with gr.Row(): with gr.Column(scale=3): gr.HTML(value="") @@ -1281,6 +1287,12 @@ def create_ui(wrap_gradio_gpu_call): outputs=[process_split_extra_row], ) + process_focal_crop.change( + fn=lambda show: gr_show(show), + inputs=[process_focal_crop], + outputs=[process_focal_crop_row], + ) + with gr.Tab(label="Train"): gr.HTML(value="
Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]
") with gr.Row(): @@ -1368,7 +1380,11 @@ def create_ui(wrap_gradio_gpu_call): process_caption_deepbooru, process_split_threshold, process_overlap_ratio, - process_entropy_focus, + process_focal_crop, + process_focal_crop_face_weight, + process_focal_crop_entropy_weight, + process_focal_crop_edges_weight, + process_focal_crop_debug, ], outputs=[ ti_output, -- cgit v1.2.1 From 54f0c1482427a5b3f2248b97be55878e742cbcb1 Mon Sep 17 00:00:00 2001 From: captin411 Date: Tue, 25 Oct 2022 16:14:13 -0700 Subject: download better face detection module dynamically --- modules/textual_inversion/autocrop.py | 20 ++++++++++++++++++++ modules/textual_inversion/preprocess.py | 13 +++++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index caaf18c8..01a92b12 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -1,4 +1,5 @@ import cv2 +import requests import os from collections import defaultdict from math import log, sqrt @@ -293,6 +294,25 @@ def is_square(w, h): return w == h +def download_and_cache_models(dirname): + download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true' + model_file_name = 'face_detection_yunet.onnx' + + if not os.path.exists(dirname): + os.makedirs(dirname) + + cache_file = os.path.join(dirname, model_file_name) + if not os.path.exists(cache_file): + print(f"downloading face detection model from '{download_url}' to '{cache_file}'") + response = requests.get(download_url) + with open(cache_file, "wb") as f: + f.write(response.content) + + if os.path.exists(cache_file): + return cache_file + return None + + class PointOfInterest: def __init__(self, x, y, weight=1.0, size=10): self.x = x diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 1e4d4de8..e13b1894 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -7,6 +7,7 @@ import tqdm import time from modules import shared, images +from modules.paths import models_path from modules.shared import opts, cmd_opts from modules.textual_inversion import autocrop if cmd_opts.deepdanbooru: @@ -146,14 +147,22 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre save_pic(splitted, index, existing_caption=existing_caption) process_default_resize = False - if process_entropy_focus and img.height != img.width: + if process_focal_crop and img.height != img.width: + + dnn_model_path = None + try: + dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv")) + except Exception as e: + print("Unable to load face detection model for auto crop selection. 
Falling back to lower quality haar method.", e) + autocrop_settings = autocrop.Settings( crop_width = width, crop_height = height, face_points_weight = process_focal_crop_face_weight, entropy_points_weight = process_focal_crop_entropy_weight, corner_points_weight = process_focal_crop_edges_weight, - annotate_image = process_focal_crop_debug + annotate_image = process_focal_crop_debug, + dnn_model_path = dnn_model_path, ) for focal in autocrop.crop_image(img, autocrop_settings): save_pic(focal, index, existing_caption=existing_caption) -- cgit v1.2.1 From df0c5ea29d7f0c682ac81f184f3e482a6450d018 Mon Sep 17 00:00:00 2001 From: captin411 Date: Tue, 25 Oct 2022 17:06:59 -0700 Subject: update default weights --- modules/textual_inversion/autocrop.py | 16 ++++++++-------- modules/ui.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index 01a92b12..9859974a 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -71,9 +71,9 @@ def crop_image(im, settings): return results def focal_point(im, settings): - corner_points = image_corner_points(im, settings) - entropy_points = image_entropy_points(im, settings) - face_points = image_face_points(im, settings) + corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else [] + entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else [] + face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else [] pois = [] @@ -144,7 +144,7 @@ def image_face_points(im, settings): settings.dnn_model_path, "", (im.width, im.height), - 0.8, # score threshold + 0.9, # score threshold 0.3, # nms threshold 5000 # keep top k before nms ) @@ -159,7 +159,7 @@ def image_face_points(im, settings): results.append( PointOfInterest( int(x + (w * 0.5)), # face focus left/right is center - int(y + (h * 0)), # face focus up/down is close to the top of the head + int(y + (h * 0.33)), # face focus up/down is close to the top of the head size = w, weight = 1/len(faces[1]) ) @@ -207,7 +207,7 @@ def image_corner_points(im, settings): np_im, maxCorners=100, qualityLevel=0.04, - minDistance=min(grayscale.width, grayscale.height)*0.03, + minDistance=min(grayscale.width, grayscale.height)*0.06, useHarrisDetector=False, ) @@ -256,8 +256,8 @@ def image_entropy_points(im, settings): def image_entropy(im): # greyscale image entropy - band = np.asarray(im.convert("L")) - # band = np.asarray(im.convert("1"), dtype=np.uint8) + # band = np.asarray(im.convert("L")) + band = np.asarray(im.convert("1"), dtype=np.uint8) hist, _ = np.histogram(band, bins=range(0, 256)) hist = hist[hist > 0] return -np.log2(hist / hist.sum()).sum() diff --git a/modules/ui.py b/modules/ui.py index 95b9c703..095200a8 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1270,7 +1270,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Row(visible=False) as process_focal_crop_row: process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05) - process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.3, minimum=0.0, maximum=1.0, step=0.05) + process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05) process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, 
step=0.05) process_focal_crop_debug = gr.Checkbox(label='Create debug image') -- cgit v1.2.1 From de096d0ce752c96e45508dcc7b9e84f7dbe10cca Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Tue, 25 Oct 2022 14:48:49 +0900 Subject: Weight initialization and More activation func add weight init add weight init option in create_hypernetwork fstringify hypernet info save weight initialization info for further debugging fill bias with zero for He/Xavier initialize LayerNorm with Normal fix loading weight_init --- modules/hypernetworks/hypernetwork.py | 47 ++++++++++++++++++++++++++++------- modules/hypernetworks/ui.py | 4 ++- modules/ui.py | 4 ++- 3 files changed, 44 insertions(+), 11 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index d647ea55..afbcdff8 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -5,6 +5,7 @@ import html import os import sys import traceback +import inspect import modules.textual_inversion.dataset import torch @@ -15,10 +16,12 @@ from modules import devices, processing, sd_models, shared from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum +from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_ from collections import defaultdict, deque from statistics import stdev, mean + class HypernetworkModule(torch.nn.Module): multiplier = 1.0 activation_dict = { @@ -26,9 +29,12 @@ class HypernetworkModule(torch.nn.Module): "leakyrelu": torch.nn.LeakyReLU, "elu": torch.nn.ELU, "swish": torch.nn.Hardswish, + "tanh": torch.nn.Tanh, + "sigmoid": torch.nn.Sigmoid, } + activation_dict.update({cls_name: cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'}) - def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False): + def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False): super().__init__() assert layer_structure is not None, "layer_structure must not be None" @@ -65,9 +71,24 @@ class HypernetworkModule(torch.nn.Module): else: for layer in self.linear: if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm: - layer.weight.data.normal_(mean=0.0, std=0.01) - layer.bias.data.zero_() - + w, b = layer.weight.data, layer.bias.data + if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm: + normal_(w, mean=0.0, std=0.01) + normal_(b, mean=0.0, std=0.005) + elif weight_init == 'XavierUniform': + xavier_uniform_(w) + zeros_(b) + elif weight_init == 'XavierNormal': + xavier_normal_(w) + zeros_(b) + elif weight_init == 'KaimingUniform': + kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu') + zeros_(b) + elif weight_init == 'KaimingNormal': + kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu') + zeros_(b) + else: + raise KeyError(f"Key {weight_init} is not defined as initialization!") self.to(devices.device) def fix_old_state_dict(self, state_dict): @@ -105,7 +126,7 @@ class Hypernetwork: filename = None name = None - def __init__(self, name=None, enable_sizes=None, layer_structure=None, 
activation_func=None, add_layer_norm=False, use_dropout=False): + def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False): self.filename = None self.name = name self.layers = {} @@ -114,13 +135,14 @@ class Hypernetwork: self.sd_checkpoint_name = None self.layer_structure = layer_structure self.activation_func = activation_func + self.weight_init = weight_init self.add_layer_norm = add_layer_norm self.use_dropout = use_dropout for size in enable_sizes or []: self.layers[size] = ( - HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), - HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout), ) def weights(self): @@ -144,6 +166,7 @@ class Hypernetwork: state_dict['layer_structure'] = self.layer_structure state_dict['activation_func'] = self.activation_func state_dict['is_layer_norm'] = self.add_layer_norm + state_dict['weight_initialization'] = self.weight_init state_dict['use_dropout'] = self.use_dropout state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name @@ -158,15 +181,21 @@ class Hypernetwork: state_dict = torch.load(filename, map_location='cpu') self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) + print(self.layer_structure) self.activation_func = state_dict.get('activation_func', None) + print(f"Activation function is {self.activation_func}") + self.weight_init = state_dict.get('weight_initialization', 'Normal') + print(f"Weight initialization is {self.weight_init}") self.add_layer_norm = state_dict.get('is_layer_norm', False) + print(f"Layer norm is set to {self.add_layer_norm}") self.use_dropout = state_dict.get('use_dropout', False) + print(f"Dropout usage is set to {self.use_dropout}" ) for size, sd in state_dict.items(): if type(size) == int: self.layers[size] = ( - HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), - HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout), ) self.name = state_dict.get('name', self.name) diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 2b472d87..2c6c0470 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -8,8 +8,9 @@ import modules.textual_inversion.textual_inversion from modules import devices, sd_hijack, shared from modules.hypernetworks import hypernetwork +keys = list(hypernetwork.HypernetworkModule.activation_dict.keys()) -def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False): +def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False): # Remove illegal 
characters from name. name = "".join( x for x in name if (x.isalnum() or x in "._- ")) @@ -25,6 +26,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, enable_sizes=[int(x) for x in enable_sizes], layer_structure=layer_structure, activation_func=activation_func, + weight_init=weight_init, add_layer_norm=add_layer_norm, use_dropout=use_dropout, ) diff --git a/modules/ui.py b/modules/ui.py index 03528968..8e343258 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1238,7 +1238,8 @@ def create_ui(wrap_gradio_gpu_call): new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") - new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"]) + new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys) + new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"]) new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout") overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") @@ -1342,6 +1343,7 @@ def create_ui(wrap_gradio_gpu_call): overwrite_old_hypernetwork, new_hypernetwork_layer_structure, new_hypernetwork_activation_func, + new_hypernetwork_initialization_option, new_hypernetwork_add_layer_norm, new_hypernetwork_use_dropout ], -- cgit v1.2.1 From 7207e3bf49ed000464d288cd67e02f0ba8614dc3 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Tue, 25 Oct 2022 15:24:59 +0900 Subject: remove duplicate keys and lowercase --- modules/hypernetworks/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index afbcdff8..842b6447 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -32,7 +32,7 @@ class HypernetworkModule(torch.nn.Module): "tanh": torch.nn.Tanh, "sigmoid": torch.nn.Sigmoid, } - activation_dict.update({cls_name: cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'}) + activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'}) def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False): super().__init__() -- cgit v1.2.1 From cbb857b675cf0f169b21515c29da492b513cc8c4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 26 Oct 2022 09:44:02 +0300 Subject: enable creating embedding with --medvram --- modules/textual_inversion/textual_inversion.py | 3 +++ 1 file changed, 3 insertions(+) 
(limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 529ed3e2..647ffe3e 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -157,6 +157,9 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): cond_model = shared.sd_model.cond_stage_model embedding_layer = cond_model.wrapped.transformer.text_model.embeddings + with devices.autocast(): + cond_model([""]) # will send cond model to GPU if lowvram/medvram is active + ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"] embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0) vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) -- cgit v1.2.1 From db9ab1a46b5ad4d36ecce76dfee04b7164249829 Mon Sep 17 00:00:00 2001 From: Stephen Date: Mon, 24 Oct 2022 11:16:07 -0400 Subject: [Bugfix][API] - Fix API response for colab users --- modules/api/api.py | 17 +++++++++++++---- modules/api/models.py | 10 ++++++---- 2 files changed, 19 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index a860a964..ba890243 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -7,6 +7,7 @@ import uvicorn from fastapi import Body, APIRouter, HTTPException from fastapi.responses import JSONResponse from pydantic import BaseModel, Field, Json +from typing import List import json import io import base64 @@ -15,12 +16,12 @@ from PIL import Image sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) class TextToImageResponse(BaseModel): - images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.") parameters: Json info: Json class ImageToImageResponse(BaseModel): - images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.") parameters: Json info: Json @@ -41,6 +42,9 @@ class Api: # convert base64 to PIL image return Image.open(io.BytesIO(imgdata)) + def __processed_info_to_json(self, processed): + return json.dumps(processed.info) + def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -65,7 +69,7 @@ class Api: i.save(buffer, format="png") b64images.append(base64.b64encode(buffer.getvalue())) - return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info)) + return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js()) @@ -111,7 +115,12 @@ class Api: i.save(buffer, format="png") b64images.append(base64.b64encode(buffer.getvalue())) - return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info)) + if (not img2imgreq.include_init_images): + # remove img2imgreq.init_images and img2imgreq.mask + img2imgreq.init_images = None + img2imgreq.mask = None + + return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js()) def extrasapi(self): raise 
NotImplementedError diff --git a/modules/api/models.py b/modules/api/models.py index f551fa35..c6d43606 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -31,6 +31,7 @@ class ModelDef(BaseModel): field_alias: str field_type: Any field_value: Any + field_exclude: bool = False class PydanticModelGenerator: @@ -68,7 +69,7 @@ class PydanticModelGenerator: field=underscore(k), field_alias=k, field_type=field_type_generator(k, v), - field_value=v.default + field_value=v.default, ) for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED ] @@ -78,7 +79,8 @@ class PydanticModelGenerator: field=underscore(fields["key"]), field_alias=fields["key"], field_type=fields["type"], - field_value=fields["default"])) + field_value=fields["default"], + field_exclude=fields["exclude"] if "exclude" in fields else False)) def generate_model(self): """ @@ -86,7 +88,7 @@ class PydanticModelGenerator: from the json and overrides provided at initialization """ fields = { - d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def + d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def } DynamicModel = create_model(self._model_name, **fields) DynamicModel.__config__.allow_population_by_field_name = True @@ -102,5 +104,5 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingImg2Img", StableDiffusionProcessingImg2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}] + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}] ).generate_model() \ No newline at end of file -- cgit v1.2.1 From b46c64c6e5b40d69521e4d50e2d35f6a35468129 Mon Sep 17 00:00:00 2001 From: Stephen Date: Mon, 24 Oct 2022 12:18:54 -0400 Subject: clean --- modules/api/api.py | 4 ---- modules/api/models.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index ba890243..6e9d6097 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -42,9 +42,6 @@ class Api: # convert base64 to PIL image return Image.open(io.BytesIO(imgdata)) - def __processed_info_to_json(self, processed): - return json.dumps(processed.info) - def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -116,7 +113,6 @@ class Api: b64images.append(base64.b64encode(buffer.getvalue())) if (not img2imgreq.include_init_images): - # remove img2imgreq.init_images and img2imgreq.mask img2imgreq.init_images = None img2imgreq.mask = None diff --git a/modules/api/models.py b/modules/api/models.py index c6d43606..079e33d9 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -69,7 +69,7 @@ class PydanticModelGenerator: field=underscore(k), field_alias=k, field_type=field_type_generator(k, v), - field_value=v.default, + field_value=v.default ) for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED ] -- cgit v1.2.1 From 146856f66d7e06a762f5ef5bf61a226057de6757 Mon 
Sep 17 00:00:00 2001 From: Milly Date: Tue, 25 Oct 2022 06:21:31 +0900 Subject: images: allow nested bracket in filename pattern --- modules/images.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 286de2ae..ed448a8a 100644 --- a/modules/images.py +++ b/modules/images.py @@ -277,7 +277,7 @@ invalid_filename_chars = '<>:"/\\|?*\n' invalid_filename_prefix = ' ' invalid_filename_postfix = ' .' re_nonletters = re.compile(r'[\s' + string.punctuation + ']+') -re_pattern = re.compile(r"([^\[\]]+|\[([^]]+)]|[\[\]]*)") +re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)") re_pattern_arg = re.compile(r"(.*)<([^>]*)>$") max_filename_part_length = 128 @@ -362,9 +362,9 @@ class FilenameGenerator: for m in re_pattern.finditer(x): text, pattern = m.groups() + res += text if pattern is None: - res += text continue pattern_args = [] @@ -385,12 +385,9 @@ class FilenameGenerator: print(f"Error adding [{pattern}] to filename", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - if replacement is None: - res += f'[{pattern}]' - else: + if replacement is not None: res += str(replacement) - - continue + continue res += f'[{pattern}]' -- cgit v1.2.1 From 757264c453eca533ee1c9ea7e9d9b45a009367d7 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 25 Oct 2022 23:39:21 +0900 Subject: default_time_format if format is blank --- modules/images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index ed448a8a..bfc2ba06 100644 --- a/modules/images.py +++ b/modules/images.py @@ -343,7 +343,7 @@ class FilenameGenerator: def datetime(self, *args): time_datetime = datetime.datetime.now() - time_format = args[0] if len(args) > 0 else self.default_time_format + time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format try: time_zone = pytz.timezone(args[1]) if len(args) > 1 else None except pytz.exceptions.UnknownTimeZoneError as _: -- cgit v1.2.1 From 9d82c351ac36d1511f5f65b24443c60250ee3e9e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 26 Oct 2022 09:56:25 +0300 Subject: fix typo in on_save_imaged/on_image_saved; hope no extension is using it yet --- modules/script_callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index dc520abc..6803d57b 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -90,7 +90,7 @@ def on_ui_settings(callback): add_callback(callbacks_ui_settings, callback) -def on_save_imaged(callback): +def on_image_saved(callback): """register a function to be called after modules.images.save_image is called. 
The callback is called with three arguments: - p - procesing object (or a dummy object with same fields if the image is saved using save button) -- cgit v1.2.1 From cb49800c08a9f6619733250401952e5571dc12f8 Mon Sep 17 00:00:00 2001 From: timntorres Date: Tue, 25 Oct 2022 01:39:59 -0700 Subject: img2img, use smartphone photos' EXIF orientation --- modules/img2img.py | 8 ++++++++ 1 file changed, 8 insertions(+) (limited to 'modules') diff --git a/modules/img2img.py b/modules/img2img.py index 8d9f7cf9..9c0cf23e 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -39,6 +39,8 @@ def process_batch(p, input_dir, output_dir, args): break img = Image.open(image) + # Use the EXIF orientation of photos taken by smartphones. + img = ImageOps.exif_transpose(img) p.init_images = [img] * p.batch_size proc = modules.scripts.scripts_img2img.run(p, *args) @@ -61,19 +63,25 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro is_batch = mode == 2 if is_inpaint: + # Drawn mask if mask_mode == 0: image = init_img_with_mask['image'] mask = init_img_with_mask['mask'] alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1') mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L') image = image.convert('RGB') + # Uploaded mask else: image = init_img_inpaint mask = init_mask_inpaint + # No mask else: image = init_img mask = None + # Use the EXIF orientation of photos taken by smartphones. + image = ImageOps.exif_transpose(image) + assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' p = StableDiffusionProcessingImg2Img( -- cgit v1.2.1 From a524d137d0a89bb19a6676dc9b8fbb5d1b580678 Mon Sep 17 00:00:00 2001 From: timntorres Date: Mon, 24 Oct 2022 23:48:05 -0700 Subject: patch bug (SeverianVoid's comment on 5245c7a) --- modules/hypernetworks/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 842b6447..8113b35b 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -487,7 +487,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if image is not None: shared.state.current_image = image - last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename) + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) last_saved_image += f", prompt: {preview_text}" shared.state.job_no = hypernetwork.step -- cgit v1.2.1 From c2dc9bfa89070b8e1d857f8773a790b752f1b709 Mon Sep 17 00:00:00 2001 From: timntorres Date: Mon, 24 Oct 2022 23:22:58 -0700 Subject: Implement PR #3189 but for embeddings. 
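In plain terms: embedding preview images were previously written with a bare PIL save, which bypassed the shared save pipeline (infotext, filename conventions, save callbacks). Routing them through images.save_image fixes that. A minimal sketch of the pattern this commit applies, assuming the surrounding training-loop variables (images_dir, embedding_name, p, processed) as they appear in the diff below:

    # before: direct save, no shared pipeline
    image.save(os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png'))

    # after: shared save path with a forced, step-suffixed filename
    forced_filename = f'{embedding_name}-{embedding.step}'
    images.save_image(image, images_dir, "", p.seed, p.prompt,
                      shared.opts.samples_format, processed.infotexts[0],
                      p=p, forced_filename=forced_filename)
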
--- modules/textual_inversion/textual_inversion.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 647ffe3e..22c7b54b 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -10,7 +10,7 @@ import csv from PIL import Image, PngImagePlugin -from modules import shared, devices, sd_hijack, processing, sd_models +from modules import shared, devices, sd_hijack, processing, sd_models, images import modules.textual_inversion.dataset from modules.textual_inversion.learn_schedule import LearnRateScheduler @@ -247,6 +247,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc last_saved_file = "" last_saved_image = "" + forced_filename = "" embedding_yet_to_be_embedded = False ititial_step = embedding.step or 0 @@ -296,8 +297,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc }) if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: - last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') - + forced_filename = f'{embedding_name}-{embedding.step}' + last_saved_image = os.path.join(images_dir, forced_filename) p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, do_not_save_grid=True, @@ -353,8 +354,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) embedding_yet_to_be_embedded = False - image.save(last_saved_image) - + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename) last_saved_image += f", prompt: {preview_text}" shared.state.job_no = embedding.step -- cgit v1.2.1 From 4875a6c217df5cc06ee2bf11fb645b172c7156a8 Mon Sep 17 00:00:00 2001 From: timntorres Date: Mon, 24 Oct 2022 23:38:07 -0700 Subject: Implement PR #3309 but for embeddings. --- modules/textual_inversion/textual_inversion.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 22c7b54b..4921bd01 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -167,6 +167,8 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): for i in range(num_vectors_per_token): vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] + # Remove illegal characters from name. + name = "".join( x for x in name if (x.isalnum() or x in "._- ")) fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt") if not overwrite_old: assert not os.path.exists(fn), f"file {fn} already exists" @@ -287,7 +289,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}") if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0: - last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt') + # Before saving, change name to match current checkpoint. 
+ embedding.name = f'{embedding_name}-{embedding.step}' + last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt') embedding.save(last_saved_file) embedding_yet_to_be_embedded = True @@ -374,6 +378,9 @@ Last saved image: {html.escape(last_saved_image)}
embedding.sd_checkpoint = checkpoint.hash embedding.sd_checkpoint_name = checkpoint.model_name embedding.cached_checksum = None + # Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention). + embedding.name = embedding_name + filename = os.path.join(shared.cmd_opts.embedding_dir, f'{embedding.name}.pt') embedding.save(filename) return embedding, filename -- cgit v1.2.1 From f4e14642173a04723200b131deb417c6c79cab17 Mon Sep 17 00:00:00 2001 From: timntorres Date: Tue, 25 Oct 2022 00:04:25 -0700 Subject: Implement PR #3625 but for embeddings. --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 4921bd01..4fcebe74 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -358,7 +358,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) embedding_yet_to_be_embedded = False - last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename) + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) last_saved_image += f", prompt: {preview_text}" shared.state.job_no = embedding.step -- cgit v1.2.1 From 1e428238db4e399b7a06ad5251cb16eef23a014d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 26 Oct 2022 11:47:07 +0300 Subject: add override_settings to API as an alternative to #3629 --- modules/processing.py | 25 ++++++++++++++++++++----- modules/shared.py | 4 ++-- 2 files changed, 22 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index c61bbfbd..4efba946 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -77,9 +77,8 @@ def get_correct_sampler(p): class StableDiffusionProcessing(): """ The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing - """ - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps:int =50, cfg_scale:float=7.0, width:int=512, height:int=512, restore_faces:bool=False, tiling:bool=False, do_not_save_samples:bool=False, do_not_save_grid:bool=False, extra_generation_params: Dict[Any,Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float =None, do_not_reload_embeddings: bool=False, denoising_strength: float = 0, ddim_discretize: str = "uniform", s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0): + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: 
int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None): self.sd_model = sd_model self.outpath_samples: str = outpath_samples self.outpath_grids: str = outpath_grids @@ -109,13 +108,14 @@ class StableDiffusionProcessing(): self.do_not_reload_embeddings = do_not_reload_embeddings self.paste_to = None self.color_corrections = None - self.denoising_strength: float = 0 + self.denoising_strength: float = denoising_strength self.sampler_noise_scheduler_override = None - self.ddim_discretize = opts.ddim_discretize + self.ddim_discretize = ddim_discretize or opts.ddim_discretize self.s_churn = s_churn or opts.s_churn self.s_tmin = s_tmin or opts.s_tmin self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option self.s_noise = s_noise or opts.s_noise + self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} if not seed_enable_extras: self.subseed = -1 @@ -129,7 +129,6 @@ class StableDiffusionProcessing(): self.all_seeds = None self.all_subseeds = None - def init(self, all_prompts, all_seeds, all_subseeds): pass @@ -351,6 +350,22 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration def process_images(p: StableDiffusionProcessing) -> Processed: + stored_opts = {k: opts.data[k] for k in p.override_settings.keys()} + + try: + for k, v in p.override_settings.items(): + opts.data[k] = v # we don't call onchange for simplicity which makes changing model, hypernet impossible + + res = process_images_inner(p) + + finally: + for k, v in stored_opts.items(): + opts.data[k] = v + + return res + + +def process_images_inner(p: StableDiffusionProcessing) -> Processed: """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" if type(p.prompt) == list: diff --git a/modules/shared.py b/modules/shared.py index 308fccce..1a9b8289 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -84,7 +84,7 @@ parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load mod parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) cmd_opts = parser.parse_args() -restricted_opts = [ +restricted_opts = { "samples_filename_pattern", "directories_filename_pattern", "outdir_samples", @@ -94,7 +94,7 @@ restricted_opts = [ "outdir_grids", "outdir_txt2img_grids", "outdir_save", -] +} devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer']) -- cgit v1.2.1 From 0cd74602531a40f72d1a75b471a8a9166135d333 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 26 Oct 2022 13:12:44 +0300 Subject: add script callback 
for before image save and change callback for after image save to use a class with parameters --- modules/images.py | 42 ++++++++++++++++++++++----------------- modules/script_callbacks.py | 48 +++++++++++++++++++++++++++++++++++++-------- 2 files changed, 64 insertions(+), 26 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index bfc2ba06..7870b5b7 100644 --- a/modules/images.py +++ b/modules/images.py @@ -451,17 +451,6 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i """ namegen = FilenameGenerator(p, seed, prompt) - if extension == 'png' and opts.enable_pnginfo and info is not None: - pnginfo = PngImagePlugin.PngInfo() - - if existing_info is not None: - for k, v in existing_info.items(): - pnginfo.add_text(k, str(v)) - - pnginfo.add_text(pnginfo_section_name, info) - else: - pnginfo = None - if save_to_dirs is None: save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt) @@ -489,19 +478,27 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i if add_number: basecount = get_next_sequence_number(path, basename) fullfn = None - fullfn_without_extension = None for i in range(500): fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}" fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}") - fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}") if not os.path.exists(fullfn): break else: fullfn = os.path.join(path, f"{file_decoration}.{extension}") - fullfn_without_extension = os.path.join(path, file_decoration) else: fullfn = os.path.join(path, f"{forced_filename}.{extension}") - fullfn_without_extension = os.path.join(path, forced_filename) + + pnginfo = existing_info or {} + if info is not None: + pnginfo[pnginfo_section_name] = info + + params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo) + script_callbacks.before_image_saved_callback(params) + + image = params.image + fullfn = params.filename + info = params.pnginfo.get(pnginfo_section_name, None) + fullfn_without_extension, extension = os.path.splitext(params.filename) def exif_bytes(): return piexif.dump({ @@ -510,12 +507,20 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i }, }) - if extension.lower() in ("jpg", "jpeg", "webp"): + if extension.lower() == '.png': + pnginfo_data = PngImagePlugin.PngInfo() + for k, v in params.pnginfo.items(): + pnginfo_data.add_text(k, str(v)) + + image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo_data) + + elif extension.lower() in (".jpg", ".jpeg", ".webp"): image.save(fullfn, quality=opts.jpeg_quality) + if opts.enable_pnginfo and info is not None: piexif.insert(exif_bytes(), fullfn) else: - image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo) + image.save(fullfn, quality=opts.jpeg_quality) target_side_length = 4000 oversize = image.width > target_side_length or image.height > target_side_length @@ -538,7 +543,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i else: txt_fullfn = None - script_callbacks.image_saved_callback(image, p, fullfn, txt_fullfn) + script_callbacks.image_saved_callback(params) + return fullfn, txt_fullfn diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 6803d57b..6ea58d61 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -9,15 +9,34 @@ def report_exception(c, job): print(traceback.format_exc(), file=sys.stderr) 
+class ImageSaveParams: + def __init__(self, image, p, filename, pnginfo): + self.image = image + """the PIL image itself""" + + self.p = p + """p object with processing parameters; either StableDiffusionProcessing or an object with same fields""" + + self.filename = filename + """name of file that the image would be saved to""" + + self.pnginfo = pnginfo + """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'""" + + ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) callbacks_model_loaded = [] callbacks_ui_tabs = [] callbacks_ui_settings = [] +callbacks_before_image_saved = [] callbacks_image_saved = [] + def clear_callbacks(): callbacks_model_loaded.clear() callbacks_ui_tabs.clear() + callbacks_ui_settings.clear() + callbacks_before_image_saved.clear() callbacks_image_saved.clear() @@ -49,10 +68,18 @@ def ui_settings_callback(): report_exception(c, 'ui_settings_callback') -def image_saved_callback(image, p, fullfn, txt_fullfn): +def before_image_saved_callback(params: ImageSaveParams): for c in callbacks_image_saved: try: - c.callback(image, p, fullfn, txt_fullfn) + c.callback(params) + except Exception: + report_exception(c, 'before_image_saved_callback') + + +def image_saved_callback(params: ImageSaveParams): + for c in callbacks_image_saved: + try: + c.callback(params) except Exception: report_exception(c, 'image_saved_callback') @@ -64,7 +91,6 @@ def add_callback(callbacks, fun): callbacks.append(ScriptCallback(filename, fun)) - def on_model_loaded(callback): """register a function to be called when the stable diffusion model is created; the model is passed as an argument""" @@ -90,11 +116,17 @@ def on_ui_settings(callback): add_callback(callbacks_ui_settings, callback) +def on_before_image_saved(callback): + """register a function to be called before an image is saved to a file. + The callback is called with one argument: + - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object. + """ + add_callback(callbacks_before_image_saved, callback) + + def on_image_saved(callback): - """register a function to be called after modules.images.save_image is called. - The callback is called with three arguments: - - p - procesing object (or a dummy object with same fields if the image is saved using save button) - - fullfn - image filename - - txt_fullfn - text file with parameters; may be None + """register a function to be called after an image is saved to a file. + The callback is called with one argument: + - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing. """ add_callback(callbacks_image_saved, callback) -- cgit v1.2.1 From 737eb28faca8be2bb996ee0930ec77d1f7ebd939 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Wed, 26 Oct 2022 14:45:33 +0100 Subject: typo: cmd_opts.embedding_dir to cmd_opts.embeddings_dir --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 4fcebe74..ff002d3e 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -380,7 +380,7 @@ Last saved image: {html.escape(last_saved_image)}
embedding.cached_checksum = None # Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention). embedding.name = embedding_name - filename = os.path.join(shared.cmd_opts.embedding_dir, f'{embedding.name}.pt') + filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding.name}.pt') embedding.save(filename) return embedding, filename -- cgit v1.2.1
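
Taken together, the script-callback changes above define a small but stable extension surface: an extension registers a function with script_callbacks.on_before_image_saved and receives an ImageSaveParams object whose image, filename and pnginfo fields may be modified in place before the file is written. A minimal sketch of such an extension — the callback name and pnginfo key are illustrative, not part of the patches:

    from modules import script_callbacks

    def tag_pnginfo(params: script_callbacks.ImageSaveParams):
        # runs before modules.images.save_image writes the file;
        # changes made here end up in the saved image's metadata
        params.pnginfo["example-note"] = "added by a hypothetical extension"

    script_callbacks.on_before_image_saved(tag_pnginfo)

Likewise, the override_settings field added to StableDiffusionProcessing lets an API caller change selected opts values for a single request: process_images snapshots the affected keys, applies the overrides, runs process_images_inner, and restores the originals in a finally block, so no onchange handlers fire and global state is untouched afterwards. A hedged usage sketch — the option key is only an example of an assumed-valid opts entry, and keys listed in shared.restricted_opts are filtered out of overrides:

    p = StableDiffusionProcessingTxt2Img(
        sd_model=shared.sd_model,
        prompt="a photo of a cat",
        override_settings={"samples_format": "jpg"},  # example key, assumed valid
    )
    processed = process_images(p)  # opts are restored afterwards, even on error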