zhiweili committed on
Commit 40b1711 · 1 Parent(s): e8e2aa0

auto dilate the hair mask

Files changed (2)
  1. .gitignore +2 -1
  2. app.py +73 -6
.gitignore CHANGED
@@ -1 +1,2 @@
-.vscode
+.vscode
+.DS_Store
app.py CHANGED
@@ -4,23 +4,30 @@ import numpy as np
 from PIL import Image
 from mediapipe.tasks import python
 from mediapipe.tasks.python import vision
-from scipy.ndimage import binary_dilation
+from scipy.ndimage import binary_dilation, label
 
 BG_COLOR = np.array([0, 0, 0], dtype=np.uint8) # black
 MASK_COLOR = np.array([255, 255, 255], dtype=np.uint8) # white
 
 MODEL_PATH = "checkpoints/selfie_multiclass_256x256.tflite"
+category_options = ["hair", "clothes", "background"]
 base_options = python.BaseOptions(model_asset_path=MODEL_PATH)
 options = vision.ImageSegmenterOptions(base_options=base_options,output_category_mask=True)
 segmenter = vision.ImageSegmenter.create_from_options(options)
-MASK_CATEGORY = segmenter.labels
+labels = segmenter.labels
 
 def segment(input_image, category):
     image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
     segmentation_result = segmenter.segment(image)
     category_mask = segmentation_result.category_mask
     category_mask_np = category_mask.numpy_view()
-    target_mask = category_mask_np == MASK_CATEGORY.index(category)
+
+    if category == "hair":
+        target_mask = get_hair_mask(category_mask_np, should_dilate=True)
+    elif category == "clothes":
+        target_mask = get_clothes_mask(category_mask_np)
+    else:
+        target_mask = category_mask_np == 0
 
     # Generate solid color images for showing the output segmentation mask.
     image_data = image.numpy_view()
@@ -29,18 +36,78 @@ def segment(input_image, category):
     bg_image = np.zeros(image_data.shape, dtype=np.uint8)
     bg_image[:] = BG_COLOR
 
-    dilated_mask = binary_dilation(target_mask, iterations=4)
-    condition = np.stack((dilated_mask,) * 3, axis=-1) > 0.2
+    condition = np.stack((target_mask,) * 3, axis=-1) > 0.2
 
     output_image = np.where(condition, fg_image, bg_image)
     output_image = Image.fromarray(output_image)
     return output_image
 
+def get_clothes_mask(category_mask_np):
+    body_skin_mask = category_mask_np == 2
+    clothes_mask = category_mask_np == 4
+    combined_mask = np.logical_or(body_skin_mask, clothes_mask)
+    combined_mask = binary_dilation(combined_mask, iterations=4)
+    return combined_mask
+
+def get_hair_mask(category_mask_np, should_dilate=False):
+    hair_mask = category_mask_np == 1
+    hair_mask = binary_dilation(hair_mask, iterations=4)
+    if not should_dilate:
+        return hair_mask
+    body_skin_mask = category_mask_np == 2
+    face_skin_mask = category_mask_np == 3
+    clothes_mask = category_mask_np == 4
+
+    face_indices = np.where(face_skin_mask)
+    min_face_y = np.min(face_indices[0])
+
+    labeled_hair, hair_features = label(hair_mask)
+    top_hair_mask = np.zeros_like(hair_mask)
+    for i in range(1, hair_features + 1):
+        component_mask = labeled_hair == i
+        component_indices = np.where(component_mask)
+        min_component_y = np.min(component_indices[0])
+        if min_component_y <= min_face_y:
+            top_hair_mask[component_mask] = True
+
+    expanded_face_mask = binary_dilation(face_skin_mask, iterations=40)
+    # Combine the reference masks (body, clothes)
+    reference_mask = np.logical_or(body_skin_mask, clothes_mask)
+    # Exclude the expanded face mask from the reference mask
+    reference_mask = np.logical_and(reference_mask, ~expanded_face_mask)
+
+    # Expand the hair mask downward until it reaches the reference areas
+    expanded_hair_mask = top_hair_mask
+    while not np.any(np.logical_and(expanded_hair_mask, reference_mask)):
+        expanded_hair_mask = binary_dilation(expanded_hair_mask, iterations=10)
+
+    # Trim the expanded_hair_mask
+    # 1. Remove the area above hair_mask by 20 pixels
+    hair_indices = np.where(hair_mask)
+    min_hair_y = np.min(hair_indices[0]) - 20
+    expanded_hair_mask[:min_hair_y, :] = 0
+
+    # 2. Remove the areas on both sides that exceed the clothing coordinates
+    clothes_indices = np.where(clothes_mask)
+    min_clothes_x = np.min(clothes_indices[1])
+    max_clothes_x = np.max(clothes_indices[1])
+    expanded_hair_mask[:, :min_clothes_x] = 0
+    expanded_hair_mask[:, max_clothes_x+1:] = 0
+
+    # exclude the face-skin, body-skin and clothes areas
+    expanded_hair_mask = np.logical_and(expanded_hair_mask, ~face_skin_mask)
+    expanded_hair_mask = np.logical_and(expanded_hair_mask, ~body_skin_mask)
+    expanded_hair_mask = np.logical_and(expanded_hair_mask, ~clothes_mask)
+    # combine the hair mask with the expanded hair mask
+    expanded_hair_mask = np.logical_or(hair_mask, expanded_hair_mask)
+
+    return expanded_hair_mask
+
 with gr.Blocks() as app:
     with gr.Row():
         with gr.Column():
             input_image = gr.Image(type='pil', label='Upload image')
-            category = gr.Dropdown(label='Category', choices=MASK_CATEGORY, value=MASK_CATEGORY[1])
+            category = gr.Dropdown(label='Category', choices=category_options, value=category_options[0])
             submit_btn = gr.Button(value='Submit', variant='primary')
         with gr.Column():
            output_image = gr.Image(type='pil', label='Image Output')
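
Note on the change: the heart of this commit is the loop in get_hair_mask that keeps dilating the hair region until it touches a reference region (body skin or clothes, with an expanded face area excluded). Below is a minimal standalone sketch of that idea on a synthetic category map; it is only an illustration, not part of the commit, and the toy array sizes and the clothes-only reference region are assumptions.

import numpy as np
from scipy.ndimage import binary_dilation

# Toy 40x40 "portrait" using the same category ids as the app:
# 1 = hair, 3 = face-skin, 4 = clothes.
category_mask_np = np.zeros((40, 40), dtype=np.uint8)
category_mask_np[2:8, 10:30] = 1    # hair on top
category_mask_np[8:20, 12:28] = 3   # face below the hair
category_mask_np[30:40, 5:35] = 4   # clothes at the bottom

hair_mask = binary_dilation(category_mask_np == 1, iterations=4)
reference_mask = category_mask_np == 4  # clothes only, for simplicity

# Grow the hair mask in steps until it reaches the reference region,
# mirroring the while-loop added in get_hair_mask.
expanded_hair_mask = hair_mask
while not np.any(np.logical_and(expanded_hair_mask, reference_mask)):
    expanded_hair_mask = binary_dilation(expanded_hair_mask, iterations=10)

print("hair pixels before/after expansion:", hair_mask.sum(), expanded_hair_mask.sum())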