# Gradio demo: MediaPipe multiclass selfie segmentation (hair / clothes / background)
import gradio as gr | |
import mediapipe as mp | |
import numpy as np | |
from PIL import Image | |
from mediapipe.tasks import python | |
from mediapipe.tasks.python import vision | |
from scipy.ndimage import binary_dilation, label | |
# Solid colors used when rendering a segmentation-mask preview.
BG_COLOR = np.array([0, 0, 0], dtype=np.uint8)  # black
MASK_COLOR = np.array([255, 255, 255], dtype=np.uint8)  # white
# Bundled MediaPipe multiclass selfie-segmentation model (256x256 input).
MODEL_PATH = "checkpoints/selfie_multiclass_256x256.tflite"
# Categories exposed in the UI dropdown (see `segment`).
category_options = ["hair", "clothes", "background"]
# Build one shared ImageSegmenter at import time; output_category_mask=True
# makes each result carry a per-pixel class-id mask.
base_options = python.BaseOptions(model_asset_path=MODEL_PATH)
options = vision.ImageSegmenterOptions(base_options=base_options,output_category_mask=True)
segmenter = vision.ImageSegmenter.create_from_options(options)
# Class-id -> name list reported by the model.
labels = segmenter.labels
# Padding (in pixels) added around the detected region when cropping.
expand_size = 40
def segment(input_image, category):
    """Segment one category out of *input_image* and return a square crop.

    Args:
        input_image: PIL RGB image.
        category: one of "hair", "clothes", "background".

    Returns:
        (crop_mask_image, crop_image_square): a square grayscale PIL mask and
        the matching square RGB crop of the input, both centered on the
        padded bounding box of the selected category.

    Raises:
        ValueError: if no pixel belongs to the requested category
            (empty-selection reduction).
    """
    # BUG FIX: PIL's Image.size is (width, height); the original unpacked it
    # as (height, width), so vertical clamping used the width and vice versa.
    original_width, original_height = input_image.size
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
    segmentation_result = segmenter.segment(image)
    category_mask_np = segmentation_result.category_mask.numpy_view()
    # Per-pixel class ids: 0 = background, 1 = hair, 2 = body-skin, 4 = clothes.
    if category == "hair":
        target_mask = get_hair_mask(category_mask_np, should_dilate=True)
    elif category == "clothes":
        target_mask = get_clothes_mask(category_mask_np)
    else:
        target_mask = category_mask_np == 0
    target_indices = np.where(target_mask)
    # Bounding box of the mask, padded by `expand_size` and clamped to the image.
    start_y = max(int(np.min(target_indices[0])) - expand_size, 0)
    end_y = min(int(np.max(target_indices[0])) + expand_size, original_height)
    start_x = max(int(np.min(target_indices[1])) - expand_size, 0)
    end_x = min(int(np.max(target_indices[1])) + expand_size, original_width)
    target_height = end_y - start_y
    target_width = end_x - start_x
    # Square side = the larger of the two crop dimensions.
    max_side_length = max(target_height, target_width)
    # Center the cropped mask inside a square canvas.
    crop_mask = target_mask[start_y:end_y, start_x:end_x]
    crop_mask_height, crop_mask_width = crop_mask.shape
    crop_mask_start_y = (max_side_length - crop_mask_height) // 2
    crop_mask_start_x = (max_side_length - crop_mask_width) // 2
    crop_mask_square = np.zeros((max_side_length, max_side_length), dtype=target_mask.dtype)
    crop_mask_square[crop_mask_start_y:crop_mask_start_y + crop_mask_height,
                     crop_mask_start_x:crop_mask_start_x + crop_mask_width] = crop_mask
    crop_mask_image = Image.fromarray((crop_mask_square * 255).astype(np.uint8))
    # Center the matching image crop on an equally sized square canvas.
    crop_image = input_image.crop((start_x, start_y, end_x, end_y))
    crop_image_square = Image.new("RGB", (max_side_length, max_side_length))
    crop_image_square.paste(crop_image, (crop_mask_start_x, crop_mask_start_y))
    # NOTE(review): the original also built an fg/bg preview image here
    # (MASK_COLOR/BG_COLOR) that was never returned — dead code, removed.
    return crop_mask_image, crop_image_square
def get_clothes_mask(category_mask_np):
    """Return a boolean mask covering body skin and clothes, slightly grown.

    Args:
        category_mask_np: 2-D array of per-pixel class ids
            (2 = body-skin, 4 = clothes).

    Returns:
        2-D boolean array, the union of both classes dilated by 4 iterations.
    """
    skin_or_clothes = (category_mask_np == 2) | (category_mask_np == 4)
    # Grow the region a little so downstream crops do not clip the edges.
    return binary_dilation(skin_or_clothes, iterations=4)
def get_hair_mask(category_mask_np, should_dilate=False):
    """Build a boolean hair mask from the multiclass category mask.

    With ``should_dilate=False`` this is just the hair class dilated by 4
    iterations.  With ``should_dilate=True`` the head-level hair components
    are additionally grown until they touch the body/clothes region, then
    trimmed, to approximate hair the model may have under-segmented.

    Args:
        category_mask_np: 2-D array of per-pixel class ids
            (1 = hair, 2 = body-skin, 3 = face-skin, 4 = clothes).
        should_dilate: enable the expansion heuristic described above.

    Returns:
        2-D boolean array marking hair pixels.
    """
    hair_mask = category_mask_np == 1
    hair_mask = binary_dilation(hair_mask, iterations=4)
    if not should_dilate:
        return hair_mask
    body_skin_mask = category_mask_np == 2
    face_skin_mask = category_mask_np == 3
    clothes_mask = category_mask_np == 4
    # Vertical extent of the face.  NOTE(review): assumes at least one
    # face-skin pixel exists — np.min/np.max raise on an empty selection.
    face_indices = np.where(face_skin_mask)
    min_face_y = np.min(face_indices[0])
    max_face_y = np.max(face_indices[0])
    # Keep only connected hair components whose top edge is at or above the
    # top of the face, i.e. hair on the head rather than stray detections.
    labeled_hair, hair_features = label(hair_mask)
    top_hair_mask = np.zeros_like(hair_mask)
    for i in range(1, hair_features + 1):
        component_mask = labeled_hair == i
        component_indices = np.where(component_mask)
        min_component_y = np.min(component_indices[0])
        if min_component_y <= min_face_y:
            top_hair_mask[component_mask] = True
    # Reference region the hair should grow toward: body skin and clothes.
    reference_mask = np.logical_or(body_skin_mask, clothes_mask)
    # Zero out every row above max_face_y + 40 so the reference region
    # starts 40 px below the bottom of the face.
    reference_mask[:max_face_y+40, :] = 0
    # Dilate the head-level hair (in all directions; trimmed below) until it
    # overlaps the reference region.
    # NOTE(review): this loops forever if reference_mask is empty — confirm
    # callers guarantee visible body skin or clothes below the face.
    expanded_hair_mask = top_hair_mask
    while not np.any(np.logical_and(expanded_hair_mask, reference_mask)):
        expanded_hair_mask = binary_dilation(expanded_hair_mask, iterations=10)
    # Trim the expanded mask:
    # 1. Remove everything more than 10 px above the original hair mask.
    hair_indices = np.where(hair_mask)
    min_hair_y = np.min(hair_indices[0])
    expanded_hair_mask[:min_hair_y - 10, :] = 0
    # 2. Remove the areas on both sides that exceed the clothing extent.
    clothes_indices = np.where(clothes_mask)
    min_clothes_x = np.min(clothes_indices[1])
    max_clothes_x = np.max(clothes_indices[1])
    expanded_hair_mask[:, :min_clothes_x] = 0
    expanded_hair_mask[:, max_clothes_x+1:] = 0
    # 3. Exclude the face-skin, body-skin and clothes areas themselves.
    expanded_hair_mask = np.logical_and(expanded_hair_mask, ~face_skin_mask)
    expanded_hair_mask = np.logical_and(expanded_hair_mask, ~body_skin_mask)
    expanded_hair_mask = np.logical_and(expanded_hair_mask, ~clothes_mask)
    # Combine the (dilated) original hair mask with the expansion.
    expanded_hair_mask = np.logical_or(hair_mask, expanded_hair_mask)
    return expanded_hair_mask
# Assemble the two-column UI: inputs on the left, results on the right.
with gr.Blocks() as app:
    with gr.Row():
        with gr.Column():
            source_image = gr.Image(type='pil', label='Upload image')
            category_choice = gr.Dropdown(label='Category', choices=category_options, value=category_options[0])
            run_button = gr.Button(value='Submit', variant='primary')
        with gr.Column():
            mask_output = gr.Image(type='pil', label='Segmentation mask')
            segmented_output = gr.Image(type='pil', label='Segmented image')
    # Wire the button to the segmentation function.
    run_button.click(
        fn=segment,
        inputs=[source_image, category_choice],
        outputs=[mask_output, segmented_output],
    )
app.launch(debug=False, show_error=True)