aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorw-e-w <40751091+w-e-w@users.noreply.github.com>2024-01-21 07:20:52 +0900
committerw-e-w <40751091+w-e-w@users.noreply.github.com>2024-01-21 09:02:18 +0900
commite36827af3254f7bac9f8c78d6d56c709959b40b6 (patch)
treeff57f209fc3de3916e83aa6aa3f5fe3ae3f53998
parentf939bce845ae07536b1c920618743af83e0b01ec (diff)
improve get_crop_region
-rw-r--r--modules/masking.py43
-rw-r--r--modules/processing.py2
2 files changed, 10 insertions, 35 deletions
diff --git a/modules/masking.py b/modules/masking.py
index be9f84c7..29a39452 100644
--- a/modules/masking.py
+++ b/modules/masking.py
@@ -3,40 +3,15 @@ from PIL import Image, ImageFilter, ImageOps
def get_crop_region(mask, pad=0):
"""finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
- For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
-
- h, w = mask.shape
-
- crop_left = 0
- for i in range(w):
- if not (mask[:, i] == 0).all():
- break
- crop_left += 1
-
- crop_right = 0
- for i in reversed(range(w)):
- if not (mask[:, i] == 0).all():
- break
- crop_right += 1
-
- crop_top = 0
- for i in range(h):
- if not (mask[i] == 0).all():
- break
- crop_top += 1
-
- crop_bottom = 0
- for i in reversed(range(h)):
- if not (mask[i] == 0).all():
- break
- crop_bottom += 1
-
- return (
- int(max(crop_left-pad, 0)),
- int(max(crop_top-pad, 0)),
- int(min(w - crop_right + pad, w)),
- int(min(h - crop_bottom + pad, h))
- )
+ For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)"""
+ mask_img = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)
+ box = mask_img.getbbox()
+ if box:
+ x1, y1, x2, y2 = box
+ else: # when no box is found
+ x1, y1 = mask_img.size
+ x2 = y2 = 0
+ return max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask_img.size[0]), min(y2 + pad, mask_img.size[1])
def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
diff --git a/modules/processing.py b/modules/processing.py
index 6b631795..72d8093b 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1562,7 +1562,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.inpaint_full_res:
self.mask_for_overlay = image_mask
mask = image_mask.convert('L')
- crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
+ crop_region = masking.get_crop_region(mask, self.inpaint_full_res_padding)
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
x1, y1, x2, y2 = crop_region