Spaces:
Sleeping
Sleeping
import gradio as gr | |
import mediapipe as mp | |
import numpy as np | |
from PIL import Image | |
from mediapipe.tasks import python | |
from mediapipe.tasks.python import vision | |
from scipy.ndimage import binary_dilation | |
# Solid fill colours for the generated mask image (one RGB triple each).
MASK_COLOR = np.array([255, 255, 255], dtype=np.uint8)  # white: selected category
BG_COLOR = np.array([0, 0, 0], dtype=np.uint8)  # black: everything else

# Relative path, so it depends on the working directory the app starts from.
MODEL_PATH = "checkpoints/selfie_multiclass_256x256.tflite"

# The segmenter is created once, at import time, and shared by every request.
base_options = python.BaseOptions(model_asset_path=MODEL_PATH)
options = vision.ImageSegmenterOptions(
    base_options=base_options,
    # Ask for the per-pixel category mask; segment() reads it.
    output_category_mask=True,
)
segmenter = vision.ImageSegmenter.create_from_options(options)

# Label names reported by the model. segment() treats a label's position in
# this list as the value marking that label's pixels in the category mask.
MASK_CATEGORY = segmenter.labels
def segment(input_image, category, dilation_iterations=4):
    """Return a black/white mask image highlighting one segmentation category.

    The image is run through the module-level MediaPipe ``segmenter``. Pixels
    whose predicted label matches ``category`` are grown by a binary dilation.
    They are then painted ``MASK_COLOR``, and every other pixel ``BG_COLOR``.

    Args:
        input_image: Image to segment. Either a PIL image (what the
            ``gr.Image(type='pil')`` input supplies) or an array-like that
            ``np.asarray`` turns into a 3-channel ``uint8`` array.
        category: One of the label names in ``MASK_CATEGORY``.
        dilation_iterations: Number of ``binary_dilation`` passes used to grow
            the mask outward. The default of 4 matches the previous fixed
            value. Values <= 0 disable dilation.

    Returns:
        A ``PIL.Image.Image`` with the input's height and width, or ``None``
        when ``input_image`` is ``None`` (nothing uploaded yet).

    Raises:
        ValueError: If ``category`` is not in ``MASK_CATEGORY``.
    """
    if input_image is None:
        # The button can be clicked before any image has been uploaded.
        return None
    if isinstance(input_image, Image.Image):
        # SRGB expects 3 channels; uploads may be RGBA, palette or grayscale.
        input_image = input_image.convert("RGB")

    # Resolve the label first so an unknown category fails before inference.
    label_index = MASK_CATEGORY.index(category)

    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
    segmentation_result = segmenter.segment(image)

    # Compare the pixel array, not the mp.Image wrapper: the wrapper does not
    # compare element-wise, and `wrapper == int` is just a bool.
    category_mask = segmentation_result.category_mask.numpy_view() == label_index
    if category_mask.ndim == 3:
        # Defensive: drop a trailing singleton channel axis if the view has one.
        category_mask = category_mask[..., 0]

    if dilation_iterations > 0:
        # scipy treats iterations < 1 as "repeat until unchanged", which would
        # flood the mask. So dilation is applied only for positive counts.
        category_mask = binary_dilation(category_mask, iterations=dilation_iterations)

    # Broadcast the (H, W, 1) boolean mask against the RGB colours -> (H, W, 3) uint8.
    output_image = np.where(category_mask[..., np.newaxis], MASK_COLOR, BG_COLOR)
    return Image.fromarray(output_image)
# Two-column layout: image upload, category picker and Submit button on the
# left; the generated mask on the right.
with gr.Blocks() as app:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label='Upload image', type='pil')
            # Preselect the model's second label (index 1).
            category = gr.Dropdown(
                choices=MASK_CATEGORY,
                value=MASK_CATEGORY[1],
                label='Category',
            )
            submit_btn = gr.Button(variant='primary', value='Submit')
        with gr.Column():
            output_image = gr.Image(label='Image Output', type='pil')

    # Clicking Submit calls segment(image, category) and shows the returned mask.
    submit_btn.click(
        inputs=[input_image, category],
        outputs=[output_image],
        fn=segment,
    )

# Launched at import time; errors are surfaced in the UI.
app.launch(show_error=True, debug=False)