File size: 2,087 Bytes
440fd96
 
 
 
 
 
 
 
 
 
 
5540863
440fd96
 
 
5540863
440fd96
5540863
440fd96
 
 
5540863
440fd96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5540863
440fd96
 
 
 
 
 
 
 
5540863
440fd96
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import gradio as gr
import mediapipe as mp
import numpy as np
from PIL import Image
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from scipy.ndimage import binary_dilation

# Output palette for the mask image: pixels of the selected category are
# painted MASK_COLOR, everything else BG_COLOR (single uint8 RGB triples).
BG_COLOR = np.zeros(3, dtype=np.uint8)  # black
MASK_COLOR = np.full(3, 255, dtype=np.uint8)  # white

# The segmenter is created once at import time and shared by every call to
# segment(). output_category_mask=True makes each result carry a per-pixel
# category mask.
MODEL_PATH = "checkpoints/selfie_multiclass_256x256.tflite"
base_options = python.BaseOptions(model_asset_path=MODEL_PATH)
options = vision.ImageSegmenterOptions(
    base_options=base_options,
    output_category_mask=True,
)
segmenter = vision.ImageSegmenter.create_from_options(options)

# Label names reported by the loaded model. A label's position in this list
# is the value segment() looks for in the category mask.
MASK_CATEGORY = segmenter.labels

def segment(input_image, category):
    """Return a black-and-white mask of the pixels labelled *category*.

    Args:
        input_image: The image to segment. The Gradio UI supplies a
            ``PIL.Image`` (``type='pil'``); ``None`` means nothing was
            uploaded.
        category: One of the label names in ``MASK_CATEGORY`` (the loaded
            segmenter's labels).

    Returns:
        A ``PIL.Image`` in which pixels of the chosen category, grown by a
        small dilation, are ``MASK_COLOR`` and all other pixels are
        ``BG_COLOR``.

    Raises:
        gr.Error: If no image was provided.
        ValueError: If *category* is not one of ``MASK_CATEGORY``.
    """
    if input_image is None:
        # Show a readable message in the UI instead of a MediaPipe failure.
        raise gr.Error("Please upload an image first.")
    if isinstance(input_image, Image.Image):
        # mp.ImageFormat.SRGB needs three channels; uploads may be RGBA,
        # palette or grayscale.
        input_image = input_image.convert("RGB")

    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
    segmentation_result = segmenter.segment(image)

    # The category mask is an mp.Image, not a NumPy array. Compare the label
    # index against its numpy_view() to get a per-pixel boolean mask
    # (presumably shaped (H, W) - TODO confirm for all MediaPipe versions).
    category_index = MASK_CATEGORY.index(category)
    category_mask = segmentation_result.category_mask.numpy_view() == category_index

    # Grow the selected region by 4 dilation steps so the mask also covers
    # the object's border pixels.
    dilated_mask = binary_dilation(category_mask, iterations=4)

    # Broadcast the boolean mask against the per-channel colors to build the
    # RGB output directly. Both colors are uint8, so the result is uint8.
    output_array = np.where(dilated_mask[..., np.newaxis], MASK_COLOR, BG_COLOR)
    return Image.fromarray(output_array)

with gr.Blocks() as app:
    with gr.Row():
        # Left column: the image to segment and the category to extract.
        with gr.Column():
            input_image = gr.Image(type='pil', label='Upload image')
            # Defaults to the second label; index 0 is presumably the
            # background class (TODO confirm against the model's label list).
            category = gr.Dropdown(label='Category', choices=MASK_CATEGORY, value=MASK_CATEGORY[1])
            submit_btn = gr.Button(value='Submit', variant='primary')
        # Right column: the generated black-and-white mask.
        with gr.Column():
            output_image = gr.Image(type='pil', label='Image Output')

    # Run segment(image, category) on click and show the result on the right.
    submit_btn.click(
        fn=segment,
        inputs=[
            input_image,
            category,
        ],
        outputs=[output_image],
    )

# Launch only when run as a script, so importing this module (e.g. for
# tests or reuse of `app`) does not start a server.
if __name__ == "__main__":
    app.launch(debug=False, show_error=True)