# s3nh's picture
# Update app.py
# ef83dce
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import numpy as np
import gradio as gr
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import SegformerForSemanticSegmentation
# Load the person-segmentation SegFormer checkpoint once, at import time, so
# every request served by this process reuses the same weights.
model = SegformerForSemanticSegmentation.from_pretrained(
    pretrained_model_name_or_path='s3nh/SegFormer-b0-person-segmentation',
)
def inference(image, chosen_model=None):
    """Overlay the predicted segmentation mask on ``image``.

    The picture is resized to 512x512, passed through the module-level
    ``model``, and the resulting logits are upsampled back to the input
    resolution. Each pixel is assigned a class id, coloured through a
    two-entry palette (class 0 black, class 1 white) and blended 50/50 with
    the original pixels.

    Args:
        image: Input picture as an array-like or PIL image. Grayscale input
            is expanded to three channels and an alpha channel is dropped.
            ``None`` (nothing submitted) is passed straight through.
        chosen_model: Unused. It now defaults to ``None`` so the function can
            be called with the single image input the UI provides; callers
            that still pass a second argument keep working.

    Returns:
        A ``PIL.Image.Image`` with the same height and width as the input,
        or ``None`` when ``image`` is ``None``.
    """
    if image is None:
        return None

    pixels = np.array(image)
    if pixels.ndim == 2:
        # Grayscale: replicate the single channel so blending works below.
        pixels = np.stack([pixels] * 3, axis=-1)
    elif pixels.shape[-1] == 4:
        # Drop alpha; the overlay is computed on colour channels only.
        pixels = pixels[..., :3]
    height, width = pixels.shape[:2]

    # Transforms
    _transform = A.Compose([
        A.Resize(height=512, width=512),
        ToTensorV2(),
    ])
    batch = _transform(image=pixels)['image'].float().unsqueeze(0)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(batch).logits

    # The logits come out at a lower resolution than the input (per the
    # SegFormer docs), so bring them back to (height, width) before deciding
    # a class per pixel.
    upsampled = nn.functional.interpolate(
        logits,
        size=(height, width),
        mode='bilinear',
        align_corners=False,
    )

    if upsampled.shape[1] == 1:
        # Single-channel head: treat it as a binary foreground probability.
        seg = (torch.sigmoid(upsampled[0, 0]) > 0.5).long()
    else:
        # Multi-channel head: pick the highest-scoring class per pixel.
        seg = upsampled.argmax(dim=1)[0]
    seg = seg.numpy()

    color_seg = np.zeros((height, width, 3), dtype=np.uint8)  # height, width, 3
    palette = np.array([[0, 0, 0], [255, 255, 255]], dtype=np.uint8)
    for label, color in enumerate(palette):
        color_seg[seg == label] = color

    # Both the input and the overlay are kept in RGB channel order; the
    # overlay is not flipped to BGR because it is blended with an RGB image.
    blended = pixels * 0.5 + color_seg * 0.5
    return Image.fromarray(blended.astype(np.uint8))
# Web UI: one image input feeding `inference`, one PIL image output.
demo = gr.Interface(
    fn=inference,
    inputs=gr.Image(),
    outputs=gr.Image(type="pil"),
    title='Segformer B0 - People segmentation',
    description='Segformer',
)

# Start the Gradio server.
demo.launch()