"""Gradio demo: render a black/white mask for one class of a MediaPipe segmenter."""

import gradio as gr
import mediapipe as mp
import numpy as np
from PIL import Image
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from scipy.ndimage import binary_dilation

BG_COLOR = np.array([0, 0, 0], dtype=np.uint8)  # black
MASK_COLOR = np.array([255, 255, 255], dtype=np.uint8)  # white
# Number of binary-dilation passes used to grow the class mask slightly
# beyond the segmenter's boundary (previously a hard-coded 4).
DILATION_ITERATIONS = 4

MODEL_PATH = "checkpoints/selfie_multiclass_256x256.tflite"

base_options = python.BaseOptions(model_asset_path=MODEL_PATH)
options = vision.ImageSegmenterOptions(
    base_options=base_options,
    output_category_mask=True,
)
segmenter = vision.ImageSegmenter.create_from_options(options)
# Label names reported by the model; a label's position in this list is the
# value the category mask stores for pixels of that class.
MASK_CATEGORY = segmenter.labels


def segment(input_image, category, dilation_iterations=DILATION_ITERATIONS):
    """Return a mask image highlighting the pixels the model assigns to *category*.

    Args:
        input_image: PIL image from the Gradio upload widget, or ``None`` when
            Submit is pressed before anything is uploaded.
        category: One of ``MASK_CATEGORY`` (the segmenter's label names).
        dilation_iterations: Binary-dilation passes applied to the class mask.
            Values <= 0 disable dilation.

    Returns:
        A PIL RGB image that is ``MASK_COLOR`` where the (dilated) class mask
        is set and ``BG_COLOR`` everywhere else.

    Raises:
        gr.Error: if no image was supplied or *category* is not a model label.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    if category not in MASK_CATEGORY:
        raise gr.Error(f"Unknown category: {category!r}")

    # SRGB expects 3-channel uint8 data; uploads may be RGBA, L, P, etc.
    rgb = np.asarray(input_image.convert("RGB"))
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb)
    segmentation_result = segmenter.segment(image)

    # category_mask is an mp.Image wrapper: compare its pixel array, not the
    # wrapper object itself (that comparison does not yield a pixel mask).
    class_map = segmentation_result.category_mask.numpy_view()
    # Tolerate either (H, W) or (H, W, 1) views.
    class_map = class_map.reshape(class_map.shape[:2])
    mask = class_map == MASK_CATEGORY.index(category)

    # scipy treats iterations < 1 as "dilate until nothing changes", which
    # would flood the mask, so only dilate for a positive count.
    if dilation_iterations > 0:
        mask = binary_dilation(mask, iterations=dilation_iterations)

    # Broadcast (H, W, 1) against the (3,) colours -> (H, W, 3) uint8 image.
    output = np.where(mask[..., np.newaxis], MASK_COLOR, BG_COLOR)
    return Image.fromarray(output)


with gr.Blocks() as app:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type='pil', label='Upload image')
            category = gr.Dropdown(
                label='Category',
                choices=MASK_CATEGORY,
                value=MASK_CATEGORY[1],
            )
            submit_btn = gr.Button(value='Submit', variant='primary')
        with gr.Column():
            output_image = gr.Image(type='pil', label='Image Output')

    submit_btn.click(
        fn=segment,
        inputs=[
            input_image,
            category,
        ],
        outputs=[output_image],
    )

app.launch(debug=False, show_error=True)