s3nh commited on
Commit
97b896c
·
1 Parent(s): 9e84ddc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -0
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import albumentations as A
2
+ import numpy as np
3
+
4
+ import gradio as gr
5
+ import torch
6
+ from torch import nn
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ from ade20k_colors import colors
10
+ from PIL import Image
11
+ from transformers import SegformerForSemanticSegmentation
12
+
13
+ model = SegformerForSemanticSegmentation.from_pretrained('s3nh/SegFormer-b0-person-segmentation')
14
+
15
+ def inference(image, chosen_model):
16
+
17
+ # Transforms
18
+ _transform = A.Compose([
19
+ A.Resize(height = 512, width=512),
20
+ ToTensorV2(),
21
+ ])
22
+ trans_image = _transform(image=np.array(image))
23
+ outputs = model(trans_image['image'].float().unsqueeze(0))
24
+
25
+
26
+ logits = outputs.logits
27
+
28
+ output = torch.sigmoid(logits).detach().numpy()[0]
29
+ output = np.transpose(output, (1,2,0))
30
+ upsampled_logits = nn.functional.interpolate(logits,
31
+ size=image.size[::-1], # (height, width)
32
+ mode='bilinear',
33
+ align_corners=False)
34
+
35
+ seg = upsampled_logits.argmax(dim=1)[0]
36
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
37
+ palette = np.array([[0, 0, 0],[255, 255, 255]])
38
+ for label, color in enumerate(palette):
39
+ color_seg[seg == label, :] = color
40
+ # Convert to BGR
41
+ color_seg = color_seg[..., ::-1]
42
+ img = np.array(image) * 0.5 + color_seg * 0.5
43
+ img = img.astype(np.uint8)
44
+ return Image.fromarray(img)
45
+
46
+
47
+ inputs = [gr.inputs.Image(label='Input Image'),
48
+ gr.inputs.Radio(['Base', 'Large'], label='BEiT Model', type='index')]
49
+
50
+ gr.Interface(
51
+ inference,
52
+ inputs,
53
+ gr.outputs.Image(label='Output'),
54
+ title='Segformer B0 - People segmentation',
55
+ description='Segformer',
56
+ ).launch()