# app.py — "Create app.py" by s3nh (commit 97b896c, 1.7 kB).
import albumentations as A
import numpy as np
import gradio as gr
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
from ade20k_colors import colors
from PIL import Image
from transformers import SegformerForSemanticSegmentation
# Hub identifier of the pretrained person-segmentation SegFormer checkpoint.
_CHECKPOINT = 's3nh/SegFormer-b0-person-segmentation'
# Loaded once at import time; every call to inference() reuses this instance.
model = SegformerForSemanticSegmentation.from_pretrained(_CHECKPOINT)
def _to_rgb_array(image):
    """Return ``image`` (PIL image or NumPy array) as an ``(H, W, 3)`` array."""
    if isinstance(image, Image.Image):
        return np.asarray(image.convert('RGB'))
    array = np.asarray(image)
    if array.ndim == 2:
        # Grayscale: replicate into three channels so the model and the blend see RGB.
        return np.stack([array] * 3, axis=-1)
    if array.shape[-1] == 4:
        # RGBA: drop alpha. Make the result contiguous because slicing produces a
        # strided view, which some resize backends reject.
        return np.ascontiguousarray(array[..., :3])
    return array


def inference(image, chosen_model):
    """Segment people in ``image`` and return the mask blended over the input.

    Args:
        image: Input picture, either a PIL image or a NumPy array. Gradio's
            ``Image`` input produces a NumPy array unless ``type='pil'`` is
            requested. Grayscale and RGBA inputs are converted to RGB first.
        chosen_model: Index chosen in the UI radio button. Currently unused:
            the module-level ``model`` is always applied. The parameter is kept
            because the Gradio interface passes one argument per input.

    Returns:
        A PIL image with the same height and width as ``image``. Each pixel is
        the 50/50 average of the input pixel and the mask colour: black for
        label 0, white for label 1.
    """
    # ToTensorV2 lives in the ``albumentations.pytorch`` submodule, which
    # ``import albumentations as A`` does not bring into scope.
    from albumentations.pytorch import ToTensorV2

    rgb = _to_rgb_array(image)
    # Take the size from the array shape. This works for both PIL and NumPy
    # input, unlike ``image.size`` (an int for ndarrays).
    height, width = rgb.shape[:2]

    transform = A.Compose([
        A.Resize(height=512, width=512),
        ToTensorV2(),
    ])
    # NOTE(review): no normalisation is applied, so pixel values stay in 0-255.
    # Confirm this matches how the checkpoint was trained.
    pixel_values = transform(image=rgb)['image'].float().unsqueeze(0)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(pixel_values).logits
        # Resize the logits back to the input resolution before taking argmax.
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=(height, width),
            mode='bilinear',
            align_corners=False,
        )
    seg = upsampled_logits.argmax(dim=1)[0].cpu().numpy()

    # Label 0 -> black, label 1 -> white. Any other label index stays black.
    palette = np.array([[0, 0, 0], [255, 255, 255]], dtype=np.uint8)
    color_seg = np.zeros((height, width, 3), dtype=np.uint8)
    for label, color in enumerate(palette):
        color_seg[seg == label] = color

    blended = rgb * 0.5 + color_seg * 0.5
    return Image.fromarray(blended.astype(np.uint8))
# Gradio UI: an image upload plus a model selector, producing one output image.
inputs = [
    # type='pil' makes Gradio pass inference() a PIL image instead of its default
    # NumPy array. inference() reads ``image.size`` as (width, height).
    gr.inputs.Image(type='pil', label='Input Image'),
    # NOTE(review): this selector and its "BEiT Model" label look like leftovers.
    # inference() ignores the chosen index. The selector is kept because the
    # interface passes one positional argument per input component.
    gr.inputs.Radio(['Base', 'Large'], label='BEiT Model', type='index'),
]
gr.Interface(
    inference,
    inputs,
    gr.outputs.Image(label='Output'),
    title='Segformer B0 - People segmentation',
    description='Segformer',
).launch()