Spaces:
Sleeping
Sleeping
zhiweili
commited on
Commit
·
e8e2aa0
1
Parent(s):
5540863
fix category mask
Browse files
app.py
CHANGED
@@ -19,7 +19,8 @@ def segment(input_image, category):
|
|
19 |
image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
|
20 |
segmentation_result = segmenter.segment(image)
|
21 |
category_mask = segmentation_result.category_mask
|
22 |
-
|
|
|
23 |
|
24 |
# Generate solid color images for showing the output segmentation mask.
|
25 |
image_data = image.numpy_view()
|
@@ -28,7 +29,7 @@ def segment(input_image, category):
|
|
28 |
bg_image = np.zeros(image_data.shape, dtype=np.uint8)
|
29 |
bg_image[:] = BG_COLOR
|
30 |
|
31 |
-
dilated_mask = binary_dilation(
|
32 |
condition = np.stack((dilated_mask,) * 3, axis=-1) > 0.2
|
33 |
|
34 |
output_image = np.where(condition, fg_image, bg_image)
|
|
|
19 |
image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
|
20 |
segmentation_result = segmenter.segment(image)
|
21 |
category_mask = segmentation_result.category_mask
|
22 |
+
category_mask_np = category_mask.numpy_view()
|
23 |
+
target_mask = category_mask_np == MASK_CATEGORY.index(category)
|
24 |
|
25 |
# Generate solid color images for showing the output segmentation mask.
|
26 |
image_data = image.numpy_view()
|
|
|
29 |
bg_image = np.zeros(image_data.shape, dtype=np.uint8)
|
30 |
bg_image[:] = BG_COLOR
|
31 |
|
32 |
+
dilated_mask = binary_dilation(target_mask, iterations=4)
|
33 |
condition = np.stack((dilated_mask,) * 3, axis=-1) > 0.2
|
34 |
|
35 |
output_image = np.where(condition, fg_image, bg_image)
|