# segment_test / app.py
# Author: zhiweili — commit 5540863 ("change to selfie_multiclass")
import gradio as gr
import mediapipe as mp
import numpy as np
from PIL import Image
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from scipy.ndimage import binary_dilation
# Solid colors used to render the output mask image.
BG_COLOR = np.array([0, 0, 0], dtype=np.uint8) # black
MASK_COLOR = np.array([255, 255, 255], dtype=np.uint8) # white

# Path to the MediaPipe multiclass selfie-segmentation TFLite model
# (expected to exist locally; loading fails at startup otherwise).
MODEL_PATH = "checkpoints/selfie_multiclass_256x256.tflite"

# Build the segmenter once at module load; `output_category_mask=True`
# makes segment() results include a per-pixel class-index mask.
base_options = python.BaseOptions(model_asset_path=MODEL_PATH)
options = vision.ImageSegmenterOptions(base_options=base_options,output_category_mask=True)
segmenter = vision.ImageSegmenter.create_from_options(options)

# Class labels reported by the model (e.g. background, hair, skin, ...);
# used both by segment() and as the dropdown choices in the UI.
MASK_CATEGORY = segmenter.labels
def segment(input_image, category):
    """Return a black-and-white mask of `category` pixels in `input_image`.

    Args:
        input_image: PIL image uploaded by the user (any mode; converted
            to RGB before segmentation).
        category: one of the labels in `MASK_CATEGORY`.

    Returns:
        PIL.Image: white where the model assigned `category`, black elsewhere.

    Raises:
        ValueError: if `category` is not in `MASK_CATEGORY`.
    """
    # Force 3-channel RGB so the array matches mp.ImageFormat.SRGB even for
    # RGBA/grayscale uploads.
    rgb_input = input_image.convert("RGB")
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(rgb_input))
    segmentation_result = segmenter.segment(image)

    # category_mask is an mp.Image of per-pixel class indices; the comparison
    # must run on its numpy view. (The original compared the mp.Image object
    # itself against an int, which never produces a per-pixel boolean mask.)
    category_index = MASK_CATEGORY.index(category)
    mask = segmentation_result.category_mask.numpy_view() == category_index

    # Solid color images for rendering the output segmentation mask.
    image_data = image.numpy_view()
    fg_image = np.full(image_data.shape, MASK_COLOR, dtype=np.uint8)
    bg_image = np.full(image_data.shape, BG_COLOR, dtype=np.uint8)

    # Grow the mask a few pixels to smooth ragged boundaries; the result is
    # already boolean, so it can drive np.where directly.
    dilated_mask = binary_dilation(mask, iterations=4)
    condition = np.stack((dilated_mask,) * 3, axis=-1)
    output_image = np.where(condition, fg_image, bg_image)
    return Image.fromarray(output_image)
# Gradio UI: upload an image, choose a segmentation category, and view the
# resulting binary mask.
with gr.Blocks() as app:
    with gr.Row():
        with gr.Column():
            img_in = gr.Image(type='pil', label='Upload image')
            cat_dd = gr.Dropdown(label='Category', choices=MASK_CATEGORY, value=MASK_CATEGORY[1])
            run_btn = gr.Button(value='Submit', variant='primary')
        with gr.Column():
            img_out = gr.Image(type='pil', label='Image Output')

    # Wire the button to the segmentation function.
    run_btn.click(fn=segment, inputs=[img_in, cat_dd], outputs=[img_out])

app.launch(debug=False, show_error=True)