zhiweili commited on
Commit
e8e2aa0
·
1 Parent(s): 5540863

fix category mask

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -19,7 +19,8 @@ def segment(input_image, category):
19
  image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
20
  segmentation_result = segmenter.segment(image)
21
  category_mask = segmentation_result.category_mask
22
- category_mask = category_mask == MASK_CATEGORY.index(category)
 
23
 
24
  # Generate solid color images for showing the output segmentation mask.
25
  image_data = image.numpy_view()
@@ -28,7 +29,7 @@ def segment(input_image, category):
28
  bg_image = np.zeros(image_data.shape, dtype=np.uint8)
29
  bg_image[:] = BG_COLOR
30
 
31
- dilated_mask = binary_dilation(category_mask.numpy_view(), iterations=4)
32
  condition = np.stack((dilated_mask,) * 3, axis=-1) > 0.2
33
 
34
  output_image = np.where(condition, fg_image, bg_image)
 
19
  image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
20
  segmentation_result = segmenter.segment(image)
21
  category_mask = segmentation_result.category_mask
22
+ category_mask_np = category_mask.numpy_view()
23
+ target_mask = category_mask_np == MASK_CATEGORY.index(category)
24
 
25
  # Generate solid color images for showing the output segmentation mask.
26
  image_data = image.numpy_view()
 
29
  bg_image = np.zeros(image_data.shape, dtype=np.uint8)
30
  bg_image[:] = BG_COLOR
31
 
32
+ dilated_mask = binary_dilation(target_mask, iterations=4)
33
  condition = np.stack((dilated_mask,) * 3, axis=-1) > 0.2
34
 
35
  output_image = np.where(condition, fg_image, bg_image)