zhiweili commited on
Commit
0f3fb3e
·
1 Parent(s): 40b1711

fix hair expand

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -60,6 +60,7 @@ def get_hair_mask(category_mask_np, should_dilate=False):
60
 
61
  face_indices = np.where(face_skin_mask)
62
  min_face_y = np.min(face_indices[0])
 
63
 
64
  labeled_hair, hair_features = label(hair_mask)
65
  top_hair_mask = np.zeros_like(hair_mask)
@@ -70,11 +71,10 @@ def get_hair_mask(category_mask_np, should_dilate=False):
70
  if min_component_y <= min_face_y:
71
  top_hair_mask[component_mask] = True
72
 
73
- expanded_face_mask = binary_dilation(face_skin_mask, iterations=40)
74
  # Combine the reference masks (body, clothes)
75
  reference_mask = np.logical_or(body_skin_mask, clothes_mask)
76
- # Exclude the expanded face mask from the reference mask
77
- reference_mask = np.logical_and(reference_mask, ~expanded_face_mask)
78
 
79
  # Expand the hair mask downward until it reaches the reference areas
80
  expanded_hair_mask = top_hair_mask
@@ -84,8 +84,8 @@ def get_hair_mask(category_mask_np, should_dilate=False):
84
  # Trim the expanded_hair_mask
85
  # 1. Remove the area above hair_mask by 20 pixels
86
  hair_indices = np.where(hair_mask)
87
- min_hair_y = np.min(hair_indices[0]) - 20
88
- expanded_hair_mask[:min_hair_y, :] = 0
89
 
90
  # 2. Remove the areas on both sides that exceed the clothing coordinates
91
  clothes_indices = np.where(clothes_mask)
 
60
 
61
  face_indices = np.where(face_skin_mask)
62
  min_face_y = np.min(face_indices[0])
63
+ max_face_y = np.max(face_indices[0])
64
 
65
  labeled_hair, hair_features = label(hair_mask)
66
  top_hair_mask = np.zeros_like(hair_mask)
 
71
  if min_component_y <= min_face_y:
72
  top_hair_mask[component_mask] = True
73
 
 
74
  # Combine the reference masks (body, clothes)
75
  reference_mask = np.logical_or(body_skin_mask, clothes_mask)
76
+ # Remove the area above the face by 40 pixels
77
+ reference_mask[:max_face_y+40, :] = 0
78
 
79
  # Expand the hair mask downward until it reaches the reference areas
80
  expanded_hair_mask = top_hair_mask
 
84
  # Trim the expanded_hair_mask
85
  # 1. Remove the area above hair_mask by 20 pixels
86
  hair_indices = np.where(hair_mask)
87
+ min_hair_y = np.min(hair_indices[0])
88
+ expanded_hair_mask[:min_hair_y - 20, :] = 0
89
 
90
  # 2. Remove the areas on both sides that exceed the clothing coordinates
91
  clothes_indices = np.where(clothes_mask)