Spaces:
Sleeping
Sleeping
zhiweili
commited on
Commit
·
0ec7070
1
Parent(s):
0f3fb3e
add crop
Browse files
app.py
CHANGED
@@ -15,8 +15,10 @@ base_options = python.BaseOptions(model_asset_path=MODEL_PATH)
|
|
15 |
options = vision.ImageSegmenterOptions(base_options=base_options,output_category_mask=True)
|
16 |
segmenter = vision.ImageSegmenter.create_from_options(options)
|
17 |
labels = segmenter.labels
|
|
|
18 |
|
19 |
def segment(input_image, category):
|
|
|
20 |
image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
|
21 |
segmentation_result = segmenter.segment(image)
|
22 |
category_mask = segmentation_result.category_mask
|
@@ -29,6 +31,41 @@ def segment(input_image, category):
|
|
29 |
else:
|
30 |
target_mask = category_mask_np == 0
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
# Generate solid color images for showing the output segmentation mask.
|
33 |
image_data = image.numpy_view()
|
34 |
fg_image = np.zeros(image_data.shape, dtype=np.uint8)
|
@@ -40,7 +77,7 @@ def segment(input_image, category):
|
|
40 |
|
41 |
output_image = np.where(condition, fg_image, bg_image)
|
42 |
output_image = Image.fromarray(output_image)
|
43 |
-
return
|
44 |
|
45 |
def get_clothes_mask(category_mask_np):
|
46 |
body_skin_mask = category_mask_np == 2
|
@@ -82,10 +119,10 @@ def get_hair_mask(category_mask_np, should_dilate=False):
|
|
82 |
expanded_hair_mask = binary_dilation(expanded_hair_mask, iterations=10)
|
83 |
|
84 |
# Trim the expanded_hair_mask
|
85 |
-
# 1. Remove the area above hair_mask by
|
86 |
hair_indices = np.where(hair_mask)
|
87 |
min_hair_y = np.min(hair_indices[0])
|
88 |
-
expanded_hair_mask[:min_hair_y -
|
89 |
|
90 |
# 2. Remove the areas on both sides that exceed the clothing coordinates
|
91 |
clothes_indices = np.where(clothes_mask)
|
@@ -110,7 +147,8 @@ with gr.Blocks() as app:
|
|
110 |
category = gr.Dropdown(label='Category', choices=category_options, value=category_options[0])
|
111 |
submit_btn = gr.Button(value='Submit', variant='primary')
|
112 |
with gr.Column():
|
113 |
-
|
|
|
114 |
|
115 |
submit_btn.click(
|
116 |
fn=segment,
|
@@ -118,7 +156,7 @@ with gr.Blocks() as app:
|
|
118 |
input_image,
|
119 |
category,
|
120 |
],
|
121 |
-
outputs=[output_image]
|
122 |
)
|
123 |
|
124 |
app.launch(debug=False, show_error=True)
|
|
|
15 |
options = vision.ImageSegmenterOptions(base_options=base_options,output_category_mask=True)
|
16 |
segmenter = vision.ImageSegmenter.create_from_options(options)
|
17 |
labels = segmenter.labels
|
18 |
+
expand_size = 40
|
19 |
|
20 |
def segment(input_image, category):
|
21 |
+
original_height, original_width = input_image.size
|
22 |
image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
|
23 |
segmentation_result = segmenter.segment(image)
|
24 |
category_mask = segmentation_result.category_mask
|
|
|
31 |
else:
|
32 |
target_mask = category_mask_np == 0
|
33 |
|
34 |
+
target_indices = np.where(target_mask)
|
35 |
+
start_y = np.min(target_indices[0]) - expand_size
|
36 |
+
if start_y < 0:
|
37 |
+
start_y = 0
|
38 |
+
end_y = np.max(target_indices[0]) + expand_size
|
39 |
+
if end_y > original_height:
|
40 |
+
end_y = original_height
|
41 |
+
start_x = np.min(target_indices[1]) - expand_size
|
42 |
+
if start_x < 0:
|
43 |
+
start_x = 0
|
44 |
+
end_x = np.max(target_indices[1]) + expand_size
|
45 |
+
if end_x > original_width:
|
46 |
+
end_x = original_width
|
47 |
+
target_height = end_y - start_y
|
48 |
+
target_width = end_x - start_x
|
49 |
+
|
50 |
+
# choose the max side length
|
51 |
+
max_side_length = max(target_height, target_width)
|
52 |
+
# calculate the crop area
|
53 |
+
crop_mask = target_mask[start_y:end_y, start_x:end_x]
|
54 |
+
crop_mask_height, crop_mask_width = crop_mask.shape
|
55 |
+
crop_mask_start_y = (max_side_length - crop_mask_height) // 2
|
56 |
+
crop_mask_end_y = crop_mask_start_y + crop_mask_height
|
57 |
+
crop_mask_start_x = (max_side_length - crop_mask_width) // 2
|
58 |
+
crop_mask_end_x = crop_mask_start_x + crop_mask_width
|
59 |
+
# create a square mask
|
60 |
+
crop_mask_square = np.zeros((max_side_length, max_side_length), dtype=target_mask.dtype)
|
61 |
+
crop_mask_square[crop_mask_start_y:crop_mask_end_y, crop_mask_start_x:crop_mask_end_x] = crop_mask
|
62 |
+
# create a square image
|
63 |
+
crop_mask_image = Image.fromarray((crop_mask_square * 255).astype(np.uint8))
|
64 |
+
|
65 |
+
crop_image = input_image.crop((start_x, start_y, end_x, end_y))
|
66 |
+
crop_image_square = Image.new("RGB", (max_side_length, max_side_length))
|
67 |
+
crop_image_square.paste(crop_image, (crop_mask_start_x, crop_mask_start_y))
|
68 |
+
|
69 |
# Generate solid color images for showing the output segmentation mask.
|
70 |
image_data = image.numpy_view()
|
71 |
fg_image = np.zeros(image_data.shape, dtype=np.uint8)
|
|
|
77 |
|
78 |
output_image = np.where(condition, fg_image, bg_image)
|
79 |
output_image = Image.fromarray(output_image)
|
80 |
+
return crop_mask_image, crop_image_square
|
81 |
|
82 |
def get_clothes_mask(category_mask_np):
|
83 |
body_skin_mask = category_mask_np == 2
|
|
|
119 |
expanded_hair_mask = binary_dilation(expanded_hair_mask, iterations=10)
|
120 |
|
121 |
# Trim the expanded_hair_mask
|
122 |
+
# 1. Remove the area above hair_mask by 10 pixels
|
123 |
hair_indices = np.where(hair_mask)
|
124 |
min_hair_y = np.min(hair_indices[0])
|
125 |
+
expanded_hair_mask[:min_hair_y - 10, :] = 0
|
126 |
|
127 |
# 2. Remove the areas on both sides that exceed the clothing coordinates
|
128 |
clothes_indices = np.where(clothes_mask)
|
|
|
147 |
category = gr.Dropdown(label='Category', choices=category_options, value=category_options[0])
|
148 |
submit_btn = gr.Button(value='Submit', variant='primary')
|
149 |
with gr.Column():
|
150 |
+
mask_image = gr.Image(type='pil', label='Segmentation mask')
|
151 |
+
output_image = gr.Image(type='pil', label='Segmented image')
|
152 |
|
153 |
submit_btn.click(
|
154 |
fn=segment,
|
|
|
156 |
input_image,
|
157 |
category,
|
158 |
],
|
159 |
+
outputs=[mask_image, output_image]
|
160 |
)
|
161 |
|
162 |
app.launch(debug=False, show_error=True)
|