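"""Gradio demo: multiclass selfie segmentation with MediaPipe.

Given an uploaded portrait, segments the hair, clothes, or background region
with the selfie_multiclass_256x256 model and returns a white-on-black mask.
"""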
import gradio as gr
import mediapipe as mp
import numpy as np
from PIL import Image
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from scipy.ndimage import binary_dilation, label
BG_COLOR = np.array([0, 0, 0], dtype=np.uint8) # black
MASK_COLOR = np.array([255, 255, 255], dtype=np.uint8) # white
MODEL_PATH = "checkpoints/selfie_multiclass_256x256.tflite"
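
# Category indices emitted by selfie_multiclass_256x256 (per the MediaPipe model card):
# 0 = background, 1 = hair, 2 = body-skin, 3 = face-skin, 4 = clothes, 5 = others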
category_options = ["hair", "clothes", "background"]
base_options = python.BaseOptions(model_asset_path=MODEL_PATH)
options = vision.ImageSegmenterOptions(base_options=base_options, output_category_mask=True)
segmenter = vision.ImageSegmenter.create_from_options(options)
labels = segmenter.labels
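
# segment() runs a single inference, converts the requested category into a
# boolean mask, and renders it as a white-on-black image for display.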
def segment(input_image, category):
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
    segmentation_result = segmenter.segment(image)
    category_mask = segmentation_result.category_mask
    category_mask_np = category_mask.numpy_view()

    if category == "hair":
        target_mask = get_hair_mask(category_mask_np, should_dilate=True)
    elif category == "clothes":
        target_mask = get_clothes_mask(category_mask_np)
    else:
        target_mask = category_mask_np == 0

    # Generate solid color images for showing the output segmentation mask.
    image_data = image.numpy_view()
    fg_image = np.zeros(image_data.shape, dtype=np.uint8)
    fg_image[:] = MASK_COLOR
    bg_image = np.zeros(image_data.shape, dtype=np.uint8)
    bg_image[:] = BG_COLOR

    condition = np.stack((target_mask,) * 3, axis=-1) > 0.2
    output_image = np.where(condition, fg_image, bg_image)
    output_image = Image.fromarray(output_image)
    return output_image
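
# get_clothes_mask() merges the body-skin and clothes classes and dilates the
# result slightly to close gaps along the segmentation boundary.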
def get_clothes_mask(category_mask_np):
    body_skin_mask = category_mask_np == 2
    clothes_mask = category_mask_np == 4
    combined_mask = np.logical_or(body_skin_mask, clothes_mask)
    combined_mask = binary_dilation(combined_mask, iterations=4)
    return combined_mask
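
# get_hair_mask() optionally grows the raw hair mask downward until it touches
# the body/clothes region, then trims the overflow, so hair hanging past the
# shoulders is covered as one contiguous mask.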
def get_hair_mask(category_mask_np, should_dilate=False):
    hair_mask = category_mask_np == 1
    hair_mask = binary_dilation(hair_mask, iterations=4)
    if not should_dilate:
        return hair_mask

    body_skin_mask = category_mask_np == 2
    face_skin_mask = category_mask_np == 3
    clothes_mask = category_mask_np == 4

    # Topmost row of the face, used to separate head hair from stray components.
    face_indices = np.where(face_skin_mask)
    min_face_y = np.min(face_indices[0])

    # Keep only the connected hair components that start at or above the top of the face.
    labeled_hair, hair_features = label(hair_mask)
    top_hair_mask = np.zeros_like(hair_mask)
    for i in range(1, hair_features + 1):
        component_mask = labeled_hair == i
        component_indices = np.where(component_mask)
        min_component_y = np.min(component_indices[0])
        if min_component_y <= min_face_y:
            top_hair_mask[component_mask] = True

    expanded_face_mask = binary_dilation(face_skin_mask, iterations=40)
    # Combine the reference masks (body, clothes) and exclude the expanded face mask.
    reference_mask = np.logical_or(body_skin_mask, clothes_mask)
    reference_mask = np.logical_and(reference_mask, ~expanded_face_mask)

    # Expand the hair mask downward until it reaches the reference areas.
    # (Assumes some body skin or clothes is visible; otherwise this loop would not terminate.)
    expanded_hair_mask = top_hair_mask
    while not np.any(np.logical_and(expanded_hair_mask, reference_mask)):
        expanded_hair_mask = binary_dilation(expanded_hair_mask, iterations=10)

    # Trim the expanded hair mask:
    # 1. Remove the area more than 20 pixels above the original hair mask
    #    (clamped at 0 so a negative index cannot wrap around).
    hair_indices = np.where(hair_mask)
    min_hair_y = max(np.min(hair_indices[0]) - 20, 0)
    expanded_hair_mask[:min_hair_y, :] = 0
    # 2. Remove the areas on both sides that exceed the clothing coordinates.
    clothes_indices = np.where(clothes_mask)
    min_clothes_x = np.min(clothes_indices[1])
    max_clothes_x = np.max(clothes_indices[1])
    expanded_hair_mask[:, :min_clothes_x] = 0
    expanded_hair_mask[:, max_clothes_x + 1:] = 0
    # 3. Exclude the face-skin, body-skin and clothes areas.
    expanded_hair_mask = np.logical_and(expanded_hair_mask, ~face_skin_mask)
    expanded_hair_mask = np.logical_and(expanded_hair_mask, ~body_skin_mask)
    expanded_hair_mask = np.logical_and(expanded_hair_mask, ~clothes_mask)

    # Combine the original hair mask with the expanded one.
    expanded_hair_mask = np.logical_or(hair_mask, expanded_hair_mask)
    return expanded_hair_mask
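
# Standalone usage sketch (hypothetical file names; the UI below calls segment()
# in exactly the same way):
#   mask = segment(Image.open("portrait.jpg").convert("RGB"), "hair")
#   mask.save("hair_mask.png")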
with gr.Blocks() as app:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type='pil', label='Upload image')
            category = gr.Dropdown(label='Category', choices=category_options, value=category_options[0])
            submit_btn = gr.Button(value='Submit', variant='primary')
        with gr.Column():
            output_image = gr.Image(type='pil', label='Image Output')

    submit_btn.click(
        fn=segment,
        inputs=[
            input_image,
            category,
        ],
        outputs=[output_image],
    )

app.launch(debug=False, show_error=True)