Spaces:
Sleeping
Sleeping
File size: 2,087 Bytes
440fd96 5540863 440fd96 5540863 440fd96 5540863 440fd96 5540863 440fd96 5540863 440fd96 5540863 440fd96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import gradio as gr
import mediapipe as mp
import numpy as np
from PIL import Image
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from scipy.ndimage import binary_dilation
# Colours used to paint the output mask (uint8 RGB triples).
BG_COLOR = np.zeros(3, dtype=np.uint8)  # black
MASK_COLOR = np.full(3, 255, dtype=np.uint8)  # white

# TFLite model file loaded by the MediaPipe image segmenter.
MODEL_PATH = "checkpoints/selfie_multiclass_256x256.tflite"

# Build one segmenter at import time and reuse it for every request.
base_options = python.BaseOptions(model_asset_path=MODEL_PATH)
options = vision.ImageSegmenterOptions(
    base_options=base_options,
    output_category_mask=True,
)
segmenter = vision.ImageSegmenter.create_from_options(options)

# Labels reported by the segmenter; a label's position in this list is the
# value compared against the category mask in `segment`.
MASK_CATEGORY = segmenter.labels
def segment(input_image, category):
    """Render a black-and-white mask for one segmentation category.

    Runs the module-level MediaPipe ``segmenter`` on ``input_image``, selects
    the pixels whose category id equals the position of ``category`` in
    ``MASK_CATEGORY``, grows that region with four iterations of binary
    dilation, and paints it ``MASK_COLOR`` on a ``BG_COLOR`` background.

    Args:
        input_image: Image supplied by the Gradio image input (a PIL image,
            or anything ``np.asarray`` accepts). ``None`` when nothing was
            uploaded.
        category: One of the label strings in ``MASK_CATEGORY``.

    Returns:
        A PIL image of the dilated mask: ``MASK_COLOR`` where the category was
        detected, ``BG_COLOR`` everywhere else.

    Raises:
        gr.Error: If no image was provided.
        ValueError: If ``category`` is not present in ``MASK_CATEGORY``.
    """
    if input_image is None:
        # Gradio passes None when Submit is clicked with an empty image slot;
        # surface a readable message instead of an opaque conversion failure.
        raise gr.Error("Please upload an image before submitting.")

    # Resolve the label first so an unknown category fails before inference.
    category_index = MASK_CATEGORY.index(category)

    # The frame is declared as SRGB below, so normalise PIL inputs that
    # arrive as RGBA, greyscale or palette images to 3-channel RGB first.
    if isinstance(input_image, Image.Image):
        input_image = input_image.convert("RGB")
    image = mp.Image(
        image_format=mp.ImageFormat.SRGB,
        data=np.asarray(input_image),
    )
    segmentation_result = segmenter.segment(image)

    # Compare the mask's NumPy view, not the mp.Image wrapper: comparing the
    # wrapper object to an int does not produce a per-pixel boolean array.
    class_ids = segmentation_result.category_mask.numpy_view()
    selected = class_ids == category_index

    # Grow the selected region slightly so the mask covers object edges.
    dilated_mask = binary_dilation(selected, iterations=4)

    # Broadcast the (H, W, 1) boolean mask against the two RGB triples;
    # assumes the category mask view is 2-D (H, W), as the original
    # np.stack(..., axis=-1) construction did.
    output_pixels = np.where(dilated_mask[..., np.newaxis], MASK_COLOR, BG_COLOR)
    return Image.fromarray(output_pixels)
# Two-column layout: inputs (image, category, submit) on the left and the
# rendered mask on the right.
with gr.Blocks() as app:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type='pil', label='Upload image')
            category = gr.Dropdown(
                label='Category',
                choices=MASK_CATEGORY,
                value=MASK_CATEGORY[1],
            )
            submit_btn = gr.Button(value='Submit', variant='primary')
        with gr.Column():
            output_image = gr.Image(type='pil', label='Image Output')

    # Event listeners must be registered inside the Blocks context.
    submit_btn.click(
        fn=segment,
        inputs=[
            input_image,
            category,
        ],
        outputs=[output_image],
    )

# show_error=True surfaces exceptions raised by `segment` in the browser UI.
app.launch(debug=False, show_error=True)