import albumentations as A
from albumentations.pytorch import ToTensorV2
import gradio as gr
import numpy as np
import torch
from torch import nn
from PIL import Image
from transformers import SegformerForSemanticSegmentation

model = SegformerForSemanticSegmentation.from_pretrained('s3nh/SegFormer-b0-person-segmentation')
model.eval()
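
# SegFormer predicts logits at 1/4 of the input resolution (128x128 for the
# 512x512 input used below), so they are upsampled back to the original
# image size before taking the per-pixel argmax.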
def inference(image):
    # Resize to the model's training resolution and convert the HWC uint8
    # array to a CHW tensor (ToTensorV2 does not rescale values, hence the
    # .float() cast before the forward pass).
    _transform = A.Compose([
        A.Resize(height=512, width=512),
        ToTensorV2(),
    ])
    image = image.convert('RGB')  # drop any alpha channel
    trans_image = _transform(image=np.array(image))
    with torch.no_grad():
        outputs = model(trans_image['image'].float().unsqueeze(0))
    logits = outputs.logits  # shape: (1, num_labels, 128, 128)
    # Upsample the logits to the original image size, then take the
    # highest-scoring class per pixel.
    upsampled_logits = nn.functional.interpolate(logits,
                                                 size=image.size[::-1],  # (height, width)
                                                 mode='bilinear',
                                                 align_corners=False)
    seg = upsampled_logits.argmax(dim=1)[0].cpu().numpy()

    # Paint each class with its palette colour: black background, white person.
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # height, width, 3
    palette = np.array([[0, 0, 0], [255, 255, 255]])
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color
    # Blend the mask with the input image at 50% opacity.
    img = np.array(image) * 0.5 + color_seg * 0.5
    img = img.astype(np.uint8)
    return Image.fromarray(img)
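
# Quick local sanity check, independent of the Gradio UI (the file name
# 'person.jpg' is a hypothetical example, not part of the original app):
# result = inference(Image.open('person.jpg'))
# result.save('overlay.png')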
gr.Interface(
    inference,
    gr.Image(type='pil', label='Input Image'),
    gr.Image(label='Output'),
    title='SegFormer B0 - Person segmentation',
    description='Binary person/background segmentation with a fine-tuned SegFormer-b0.',
).launch()
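
# Rough dependency list (an assumption; the Space does not pin versions):
# pip install torch transformers albumentations gradio numpy Pillow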