Spaces:
Sleeping
Sleeping
zhiweili
commited on
Commit
·
6bdded7
1
Parent(s):
0ec7070
add croper
Browse files- .gitignore +2 -1
- app.py +4 -49
- croper.py +71 -0
.gitignore
CHANGED
@@ -1,2 +1,3 @@
|
|
1 |
.vscode
|
2 |
-
.DS_Store
|
|
|
|
1 |
.vscode
|
2 |
+
.DS_Store
|
3 |
+
__pycache__
|
app.py
CHANGED
@@ -1,10 +1,12 @@
|
|
1 |
import gradio as gr
|
2 |
import mediapipe as mp
|
3 |
import numpy as np
|
|
|
4 |
from PIL import Image
|
5 |
from mediapipe.tasks import python
|
6 |
from mediapipe.tasks.python import vision
|
7 |
from scipy.ndimage import binary_dilation, label
|
|
|
8 |
|
9 |
BG_COLOR = np.array([0, 0, 0], dtype=np.uint8) # black
|
10 |
MASK_COLOR = np.array([255, 255, 255], dtype=np.uint8) # white
|
@@ -15,10 +17,8 @@ base_options = python.BaseOptions(model_asset_path=MODEL_PATH)
|
|
15 |
options = vision.ImageSegmenterOptions(base_options=base_options,output_category_mask=True)
|
16 |
segmenter = vision.ImageSegmenter.create_from_options(options)
|
17 |
labels = segmenter.labels
|
18 |
-
expand_size = 40
|
19 |
|
20 |
def segment(input_image, category):
|
21 |
-
original_height, original_width = input_image.size
|
22 |
image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
|
23 |
segmentation_result = segmenter.segment(image)
|
24 |
category_mask = segmentation_result.category_mask
|
@@ -30,54 +30,9 @@ def segment(input_image, category):
|
|
30 |
target_mask = get_clothes_mask(category_mask_np)
|
31 |
else:
|
32 |
target_mask = category_mask_np == 0
|
|
|
33 |
|
34 |
-
|
35 |
-
start_y = np.min(target_indices[0]) - expand_size
|
36 |
-
if start_y < 0:
|
37 |
-
start_y = 0
|
38 |
-
end_y = np.max(target_indices[0]) + expand_size
|
39 |
-
if end_y > original_height:
|
40 |
-
end_y = original_height
|
41 |
-
start_x = np.min(target_indices[1]) - expand_size
|
42 |
-
if start_x < 0:
|
43 |
-
start_x = 0
|
44 |
-
end_x = np.max(target_indices[1]) + expand_size
|
45 |
-
if end_x > original_width:
|
46 |
-
end_x = original_width
|
47 |
-
target_height = end_y - start_y
|
48 |
-
target_width = end_x - start_x
|
49 |
-
|
50 |
-
# choose the max side length
|
51 |
-
max_side_length = max(target_height, target_width)
|
52 |
-
# calculate the crop area
|
53 |
-
crop_mask = target_mask[start_y:end_y, start_x:end_x]
|
54 |
-
crop_mask_height, crop_mask_width = crop_mask.shape
|
55 |
-
crop_mask_start_y = (max_side_length - crop_mask_height) // 2
|
56 |
-
crop_mask_end_y = crop_mask_start_y + crop_mask_height
|
57 |
-
crop_mask_start_x = (max_side_length - crop_mask_width) // 2
|
58 |
-
crop_mask_end_x = crop_mask_start_x + crop_mask_width
|
59 |
-
# create a square mask
|
60 |
-
crop_mask_square = np.zeros((max_side_length, max_side_length), dtype=target_mask.dtype)
|
61 |
-
crop_mask_square[crop_mask_start_y:crop_mask_end_y, crop_mask_start_x:crop_mask_end_x] = crop_mask
|
62 |
-
# create a square image
|
63 |
-
crop_mask_image = Image.fromarray((crop_mask_square * 255).astype(np.uint8))
|
64 |
-
|
65 |
-
crop_image = input_image.crop((start_x, start_y, end_x, end_y))
|
66 |
-
crop_image_square = Image.new("RGB", (max_side_length, max_side_length))
|
67 |
-
crop_image_square.paste(crop_image, (crop_mask_start_x, crop_mask_start_y))
|
68 |
-
|
69 |
-
# Generate solid color images for showing the output segmentation mask.
|
70 |
-
image_data = image.numpy_view()
|
71 |
-
fg_image = np.zeros(image_data.shape, dtype=np.uint8)
|
72 |
-
fg_image[:] = MASK_COLOR
|
73 |
-
bg_image = np.zeros(image_data.shape, dtype=np.uint8)
|
74 |
-
bg_image[:] = BG_COLOR
|
75 |
-
|
76 |
-
condition = np.stack((target_mask,) * 3, axis=-1) > 0.2
|
77 |
-
|
78 |
-
output_image = np.where(condition, fg_image, bg_image)
|
79 |
-
output_image = Image.fromarray(output_image)
|
80 |
-
return crop_mask_image, crop_image_square
|
81 |
|
82 |
def get_clothes_mask(category_mask_np):
|
83 |
body_skin_mask = category_mask_np == 2
|
|
|
1 |
import gradio as gr
|
2 |
import mediapipe as mp
|
3 |
import numpy as np
|
4 |
+
|
5 |
from PIL import Image
|
6 |
from mediapipe.tasks import python
|
7 |
from mediapipe.tasks.python import vision
|
8 |
from scipy.ndimage import binary_dilation, label
|
9 |
+
from croper import Croper
|
10 |
|
11 |
BG_COLOR = np.array([0, 0, 0], dtype=np.uint8) # black
|
12 |
MASK_COLOR = np.array([255, 255, 255], dtype=np.uint8) # white
|
|
|
17 |
options = vision.ImageSegmenterOptions(base_options=base_options,output_category_mask=True)
|
18 |
segmenter = vision.ImageSegmenter.create_from_options(options)
|
19 |
labels = segmenter.labels
|
|
|
20 |
|
21 |
def segment(input_image, category):
|
|
|
22 |
image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
|
23 |
segmentation_result = segmenter.segment(image)
|
24 |
category_mask = segmentation_result.category_mask
|
|
|
30 |
target_mask = get_clothes_mask(category_mask_np)
|
31 |
else:
|
32 |
target_mask = category_mask_np == 0
|
33 |
+
croper = Croper(input_image, target_mask)
|
34 |
|
35 |
+
return croper.corp_mask_image()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
def get_clothes_mask(category_mask_np):
|
38 |
body_skin_mask = category_mask_np == 2
|
croper.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import PIL
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
class Croper:
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
input_image: PIL.Image,
|
10 |
+
target_mask: np.ndarray,
|
11 |
+
):
|
12 |
+
self.input_image = input_image
|
13 |
+
self.target_mask = target_mask
|
14 |
+
|
15 |
+
def corp_mask_image(self):
|
16 |
+
target_mask = self.target_mask
|
17 |
+
input_image = self.input_image
|
18 |
+
crop_length = 512
|
19 |
+
expand_size = 40
|
20 |
+
original_width, original_height = input_image.size
|
21 |
+
mask_indices = np.where(target_mask)
|
22 |
+
start_y = np.min(mask_indices[0]) - expand_size
|
23 |
+
if start_y < 0:
|
24 |
+
start_y = 0
|
25 |
+
end_y = np.max(mask_indices[0]) + expand_size
|
26 |
+
if end_y > original_height:
|
27 |
+
end_y = original_height
|
28 |
+
start_x = np.min(mask_indices[1]) - expand_size
|
29 |
+
if start_x < 0:
|
30 |
+
start_x = 0
|
31 |
+
end_x = np.max(mask_indices[1]) + expand_size
|
32 |
+
if end_x > original_width:
|
33 |
+
end_x = original_width
|
34 |
+
mask_height = end_y - start_y
|
35 |
+
mask_width = end_x - start_x
|
36 |
+
|
37 |
+
# choose the max side length
|
38 |
+
max_side_length = max(mask_height, mask_width)
|
39 |
+
# calculate the crop area
|
40 |
+
crop_mask = target_mask[start_y:end_y, start_x:end_x]
|
41 |
+
crop_mask_start_y = (max_side_length - mask_height) // 2
|
42 |
+
crop_mask_end_y = crop_mask_start_y + mask_height
|
43 |
+
crop_mask_start_x = (max_side_length - mask_width) // 2
|
44 |
+
crop_mask_end_x = crop_mask_start_x + mask_width
|
45 |
+
# create a square mask
|
46 |
+
square_mask = np.zeros((max_side_length, max_side_length), dtype=target_mask.dtype)
|
47 |
+
square_mask[crop_mask_start_y:crop_mask_end_y, crop_mask_start_x:crop_mask_end_x] = crop_mask
|
48 |
+
square_mask_image = Image.fromarray((square_mask * 255).astype(np.uint8))
|
49 |
+
|
50 |
+
crop_image = input_image.crop((start_x, start_y, end_x, end_y))
|
51 |
+
square_image = Image.new("RGB", (max_side_length, max_side_length))
|
52 |
+
square_image.paste(crop_image, (crop_mask_start_x, crop_mask_start_y))
|
53 |
+
|
54 |
+
self.origin_start_x = start_x
|
55 |
+
self.origin_start_y = start_y
|
56 |
+
self.origin_end_x = end_x
|
57 |
+
self.origin_end_y = end_y
|
58 |
+
|
59 |
+
self.square_start_x = crop_mask_start_x
|
60 |
+
self.square_start_y = crop_mask_start_y
|
61 |
+
self.square_end_x = crop_mask_end_x
|
62 |
+
self.square_end_y = crop_mask_end_y
|
63 |
+
|
64 |
+
self.square_length = max_side_length
|
65 |
+
self.square_mask_image = square_mask_image
|
66 |
+
self.square_image = square_image
|
67 |
+
|
68 |
+
self.resized_square_mask_image = square_mask_image.resize((crop_length, crop_length))
|
69 |
+
self.resized_square_image = square_image.resize((crop_length, crop_length))
|
70 |
+
|
71 |
+
return self.square_image, self.resized_square_image
|