# semantic-10 / app.py
# adpro's picture
# Update app.py
# af6cfb2
# raw
# history blame
# 958 Bytes
import torch.hub
import gluoncv
import mxnet as mx
from gluoncv.utils.viz import get_color_pallete
import gradio as gr
import numpy as np
from PIL import Image
from gluoncv.data.transforms.presets.segmentation import test_transform
# Run inference on the CPU.
ctx = mx.cpu(0)
# segmentation() below is an MXNet/GluonCV pipeline: it feeds the model an
# MXNet NDArray produced by gluoncv's test_transform, calls model.predict(),
# and colours the argmax with the "ade20k" palette. A PyTorch module (as
# returned by torch.hub) has no predict() method and cannot consume MXNet
# NDArrays, so load the GluonCV DeepLab ResNet-101 model trained on ADE20K,
# which matches both the calling code and the palette.
model = gluoncv.model_zoo.get_model("deeplab_resnet101_ade", pretrained=True, ctx=ctx)
def segmentation(image):
    """Segment an image and return a colour-coded class mask.

    Args:
        image: Value supplied by the ``gr.Image`` input component — presumably
            an HxWx3 uint8 numpy array (gradio's default ``type="numpy"``),
            or ``None`` when the user submits without providing an image.

    Returns:
        The palettised image produced by ``get_color_pallete`` with the
        ``"ade20k"`` colour table, or ``None`` when no image was given.
    """
    # Gradio calls the function with None when the input component is empty;
    # return an empty output instead of crashing inside Image.fromarray.
    if image is None:
        return None
    img = Image.fromarray(image)
    img = mx.ndarray.array(img)
    # gluoncv preset: expected to normalise the image, add a batch axis and
    # place it on ``ctx`` — see gluoncv's segmentation presets.
    img = test_transform(img, ctx)
    output = model.predict(img)
    # Per-pixel label = argmax over axis 1 (presumably the class-score axis),
    # with singleton axes squeezed away before converting to numpy.
    predict = mx.nd.squeeze(mx.nd.argmax(output, 1)).asnumpy()
    mask = get_color_pallete(predict, "ade20k")
    return mask
# Gradio UI: one image in, one colour-coded segmentation mask out.
image_in = gr.Image()
image_out = gr.Image()
description = "MXNet Image Segmentation Model"
# Example input(s) offered below the UI; the file is expected to ship with the app.
examples = ['cat.jpeg']
Iface = gr.Interface(
    fn=segmentation,
    inputs=image_in,
    outputs=image_out,
    title="Semantic Segmentation - MXNet",
    # Previously defined above but never passed, so it was never displayed.
    description=description,
    examples=examples,
)
# Launch separately so Iface stays bound to the Interface object;
# launch() itself returns a tuple of server/URL values, not the app.
Iface.launch()