import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import numpy as np
import gradio as gr
import torch
from torch import nn
from PIL import Image
from transformers import SegformerForSemanticSegmentation

# Load the fine-tuned person-segmentation checkpoint and switch to inference mode
model = SegformerForSemanticSegmentation.from_pretrained('s3nh/SegFormer-b0-person-segmentation')
model.eval()
def inference(image):
    # Resize to the model's input resolution and convert to a CHW tensor
    _transform = A.Compose([
        A.Resize(height=512, width=512),
        ToTensorV2(),
    ])
    trans_image = _transform(image=np.array(image))

    with torch.no_grad():
        outputs = model(trans_image['image'].float().unsqueeze(0))
    logits = outputs.logits  # shape: (1, num_labels, height/4, width/4)

    # Upsample logits back to the original image size (PIL .size is (width, height))
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],  # (height, width)
        mode='bilinear',
        align_corners=False,
    )

    # Per-pixel class prediction: 0 = background, 1 = person
    seg = upsampled_logits.argmax(dim=1)[0].numpy()

    # Colorize the mask: background black, person white
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # (height, width, 3)
    palette = np.array([[0, 0, 0], [255, 255, 255]])
    for label, color in enumerate(palette):
        color_seg[seg == label] = color

    # Blend the mask with the original image
    img = np.array(image) * 0.5 + color_seg * 0.5
    return Image.fromarray(img.astype(np.uint8))
demo = gr.Interface(
    inference,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title='Segformer B0 - People segmentation',
    description='Person segmentation with SegFormer-b0: the predicted person mask is blended over the input image.',
)

demo.launch()