zhiweili commited on
Commit
0ec7070
·
1 Parent(s): 0f3fb3e
Files changed (1) hide show
  1. app.py +43 -5
app.py CHANGED
@@ -15,8 +15,10 @@ base_options = python.BaseOptions(model_asset_path=MODEL_PATH)
15
  options = vision.ImageSegmenterOptions(base_options=base_options,output_category_mask=True)
16
  segmenter = vision.ImageSegmenter.create_from_options(options)
17
  labels = segmenter.labels
 
18
 
19
  def segment(input_image, category):
 
20
  image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
21
  segmentation_result = segmenter.segment(image)
22
  category_mask = segmentation_result.category_mask
@@ -29,6 +31,41 @@ def segment(input_image, category):
29
  else:
30
  target_mask = category_mask_np == 0
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  # Generate solid color images for showing the output segmentation mask.
33
  image_data = image.numpy_view()
34
  fg_image = np.zeros(image_data.shape, dtype=np.uint8)
@@ -40,7 +77,7 @@ def segment(input_image, category):
40
 
41
  output_image = np.where(condition, fg_image, bg_image)
42
  output_image = Image.fromarray(output_image)
43
- return output_image
44
 
45
  def get_clothes_mask(category_mask_np):
46
  body_skin_mask = category_mask_np == 2
@@ -82,10 +119,10 @@ def get_hair_mask(category_mask_np, should_dilate=False):
82
  expanded_hair_mask = binary_dilation(expanded_hair_mask, iterations=10)
83
 
84
  # Trim the expanded_hair_mask
85
- # 1. Remove the area above hair_mask by 20 pixels
86
  hair_indices = np.where(hair_mask)
87
  min_hair_y = np.min(hair_indices[0])
88
- expanded_hair_mask[:min_hair_y - 20, :] = 0
89
 
90
  # 2. Remove the areas on both sides that exceed the clothing coordinates
91
  clothes_indices = np.where(clothes_mask)
@@ -110,7 +147,8 @@ with gr.Blocks() as app:
110
  category = gr.Dropdown(label='Category', choices=category_options, value=category_options[0])
111
  submit_btn = gr.Button(value='Submit', variant='primary')
112
  with gr.Column():
113
- output_image = gr.Image(type='pil', label='Image Output')
 
114
 
115
  submit_btn.click(
116
  fn=segment,
@@ -118,7 +156,7 @@ with gr.Blocks() as app:
118
  input_image,
119
  category,
120
  ],
121
- outputs=[output_image]
122
  )
123
 
124
  app.launch(debug=False, show_error=True)
 
15
  options = vision.ImageSegmenterOptions(base_options=base_options,output_category_mask=True)
16
  segmenter = vision.ImageSegmenter.create_from_options(options)
17
  labels = segmenter.labels
18
+ expand_size = 40
19
 
20
  def segment(input_image, category):
21
+ original_height, original_width = input_image.size
22
  image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
23
  segmentation_result = segmenter.segment(image)
24
  category_mask = segmentation_result.category_mask
 
31
  else:
32
  target_mask = category_mask_np == 0
33
 
34
+ target_indices = np.where(target_mask)
35
+ start_y = np.min(target_indices[0]) - expand_size
36
+ if start_y < 0:
37
+ start_y = 0
38
+ end_y = np.max(target_indices[0]) + expand_size
39
+ if end_y > original_height:
40
+ end_y = original_height
41
+ start_x = np.min(target_indices[1]) - expand_size
42
+ if start_x < 0:
43
+ start_x = 0
44
+ end_x = np.max(target_indices[1]) + expand_size
45
+ if end_x > original_width:
46
+ end_x = original_width
47
+ target_height = end_y - start_y
48
+ target_width = end_x - start_x
49
+
50
+ # choose the max side length
51
+ max_side_length = max(target_height, target_width)
52
+ # calculate the crop area
53
+ crop_mask = target_mask[start_y:end_y, start_x:end_x]
54
+ crop_mask_height, crop_mask_width = crop_mask.shape
55
+ crop_mask_start_y = (max_side_length - crop_mask_height) // 2
56
+ crop_mask_end_y = crop_mask_start_y + crop_mask_height
57
+ crop_mask_start_x = (max_side_length - crop_mask_width) // 2
58
+ crop_mask_end_x = crop_mask_start_x + crop_mask_width
59
+ # create a square mask
60
+ crop_mask_square = np.zeros((max_side_length, max_side_length), dtype=target_mask.dtype)
61
+ crop_mask_square[crop_mask_start_y:crop_mask_end_y, crop_mask_start_x:crop_mask_end_x] = crop_mask
62
+ # create a square image
63
+ crop_mask_image = Image.fromarray((crop_mask_square * 255).astype(np.uint8))
64
+
65
+ crop_image = input_image.crop((start_x, start_y, end_x, end_y))
66
+ crop_image_square = Image.new("RGB", (max_side_length, max_side_length))
67
+ crop_image_square.paste(crop_image, (crop_mask_start_x, crop_mask_start_y))
68
+
69
  # Generate solid color images for showing the output segmentation mask.
70
  image_data = image.numpy_view()
71
  fg_image = np.zeros(image_data.shape, dtype=np.uint8)
 
77
 
78
  output_image = np.where(condition, fg_image, bg_image)
79
  output_image = Image.fromarray(output_image)
80
+ return crop_mask_image, crop_image_square
81
 
82
  def get_clothes_mask(category_mask_np):
83
  body_skin_mask = category_mask_np == 2
 
119
  expanded_hair_mask = binary_dilation(expanded_hair_mask, iterations=10)
120
 
121
  # Trim the expanded_hair_mask
122
+ # 1. Remove the area above hair_mask by 10 pixels
123
  hair_indices = np.where(hair_mask)
124
  min_hair_y = np.min(hair_indices[0])
125
+ expanded_hair_mask[:min_hair_y - 10, :] = 0
126
 
127
  # 2. Remove the areas on both sides that exceed the clothing coordinates
128
  clothes_indices = np.where(clothes_mask)
 
147
  category = gr.Dropdown(label='Category', choices=category_options, value=category_options[0])
148
  submit_btn = gr.Button(value='Submit', variant='primary')
149
  with gr.Column():
150
+ mask_image = gr.Image(type='pil', label='Segmentation mask')
151
+ output_image = gr.Image(type='pil', label='Segmented image')
152
 
153
  submit_btn.click(
154
  fn=segment,
 
156
  input_image,
157
  category,
158
  ],
159
+ outputs=[mask_image, output_image]
160
  )
161
 
162
  app.launch(debug=False, show_error=True)