chansung commited on
Commit
cf99308
·
1 Parent(s): 7b929e1

Create new file

Browse files
Files changed (1) hide show
  1. app.py +243 -0
app.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ import numpy as np
4
+ from PIL import Image
5
+ from matplotlib import gridspec
6
+ import matplotlib.pyplot as plt
7
+ import onnxruntime as ort
8
+
9
+ import wget
10
+
11
+ def ade_palette():
12
+ """ADE20K palette that maps each class to RGB values."""
13
+ return [
14
+ [120, 120, 120],
15
+ [180, 120, 120],
16
+ [6, 230, 230],
17
+ [80, 50, 50],
18
+ [4, 200, 3],
19
+ [120, 120, 80],
20
+ [140, 140, 140],
21
+ [204, 5, 255],
22
+ [230, 230, 230],
23
+ [4, 250, 7],
24
+ [224, 5, 255],
25
+ [235, 255, 7],
26
+ [150, 5, 61],
27
+ [120, 120, 70],
28
+ [8, 255, 51],
29
+ [255, 6, 82],
30
+ [143, 255, 140],
31
+ [204, 255, 4],
32
+ [255, 51, 7],
33
+ [204, 70, 3],
34
+ [0, 102, 200],
35
+ [61, 230, 250],
36
+ [255, 6, 51],
37
+ [11, 102, 255],
38
+ [255, 7, 71],
39
+ [255, 9, 224],
40
+ [9, 7, 230],
41
+ [220, 220, 220],
42
+ [255, 9, 92],
43
+ [112, 9, 255],
44
+ [8, 255, 214],
45
+ [7, 255, 224],
46
+ [255, 184, 6],
47
+ [10, 255, 71],
48
+ [255, 41, 10],
49
+ [7, 255, 255],
50
+ [224, 255, 8],
51
+ [102, 8, 255],
52
+ [255, 61, 6],
53
+ [255, 194, 7],
54
+ [255, 122, 8],
55
+ [0, 255, 20],
56
+ [255, 8, 41],
57
+ [255, 5, 153],
58
+ [6, 51, 255],
59
+ [235, 12, 255],
60
+ [160, 150, 20],
61
+ [0, 163, 255],
62
+ [140, 140, 140],
63
+ [250, 10, 15],
64
+ [20, 255, 0],
65
+ [31, 255, 0],
66
+ [255, 31, 0],
67
+ [255, 224, 0],
68
+ [153, 255, 0],
69
+ [0, 0, 255],
70
+ [255, 71, 0],
71
+ [0, 235, 255],
72
+ [0, 173, 255],
73
+ [31, 0, 255],
74
+ [11, 200, 200],
75
+ [255, 82, 0],
76
+ [0, 255, 245],
77
+ [0, 61, 255],
78
+ [0, 255, 112],
79
+ [0, 255, 133],
80
+ [255, 0, 0],
81
+ [255, 163, 0],
82
+ [255, 102, 0],
83
+ [194, 255, 0],
84
+ [0, 143, 255],
85
+ [51, 255, 0],
86
+ [0, 82, 255],
87
+ [0, 255, 41],
88
+ [0, 255, 173],
89
+ [10, 0, 255],
90
+ [173, 255, 0],
91
+ [0, 255, 153],
92
+ [255, 92, 0],
93
+ [255, 0, 255],
94
+ [255, 0, 245],
95
+ [255, 0, 102],
96
+ [255, 173, 0],
97
+ [255, 0, 20],
98
+ [255, 184, 184],
99
+ [0, 31, 255],
100
+ [0, 255, 61],
101
+ [0, 71, 255],
102
+ [255, 0, 204],
103
+ [0, 255, 194],
104
+ [0, 255, 82],
105
+ [0, 10, 255],
106
+ [0, 112, 255],
107
+ [51, 0, 255],
108
+ [0, 194, 255],
109
+ [0, 122, 255],
110
+ [0, 255, 163],
111
+ [255, 153, 0],
112
+ [0, 255, 10],
113
+ [255, 112, 0],
114
+ [143, 255, 0],
115
+ [82, 0, 255],
116
+ [163, 255, 0],
117
+ [255, 235, 0],
118
+ [8, 184, 170],
119
+ [133, 0, 255],
120
+ [0, 255, 92],
121
+ [184, 0, 255],
122
+ [255, 0, 31],
123
+ [0, 184, 255],
124
+ [0, 214, 255],
125
+ [255, 0, 112],
126
+ [92, 255, 0],
127
+ [0, 224, 255],
128
+ [112, 224, 255],
129
+ [70, 184, 160],
130
+ [163, 0, 255],
131
+ [153, 0, 255],
132
+ [71, 255, 0],
133
+ [255, 0, 163],
134
+ [255, 204, 0],
135
+ [255, 0, 143],
136
+ [0, 255, 235],
137
+ [133, 255, 0],
138
+ [255, 0, 235],
139
+ [245, 0, 255],
140
+ [255, 0, 122],
141
+ [255, 245, 0],
142
+ [10, 190, 212],
143
+ [214, 255, 0],
144
+ [0, 204, 255],
145
+ [20, 0, 255],
146
+ [255, 255, 0],
147
+ [0, 153, 255],
148
+ [0, 41, 255],
149
+ [0, 255, 204],
150
+ [41, 0, 255],
151
+ [41, 255, 0],
152
+ [173, 0, 255],
153
+ [0, 245, 255],
154
+ [71, 0, 255],
155
+ [122, 0, 255],
156
+ [0, 255, 184],
157
+ [0, 92, 255],
158
+ [184, 255, 0],
159
+ [0, 133, 255],
160
+ [255, 214, 0],
161
+ [25, 194, 194],
162
+ [102, 255, 0],
163
+ [92, 0, 255],
164
+ ]
165
+
166
+ url='https://github.com/deep-diver/segformer-tf-transformers/releases/download/1.0/segformer-b5-finetuned-ade-640-640.onnx'
167
+ labels_list = []
168
+ colormap = np.asarray(ade_palette())
169
+
170
+ model_path = wget.download(url)
171
+ sess = ort.InferenceSession(model_path)
172
+
173
+ with open(r'labels.txt', 'r') as fp:
174
+ for line in fp:
175
+ labels_list.append(line[:-1])
176
+
177
+ def label_to_color_image(label):
178
+ if label.ndim != 2:
179
+ raise ValueError("Expect 2-D input label")
180
+
181
+ if np.max(label) >= len(colormap):
182
+ raise ValueError("label value too large.")
183
+
184
+ return colormap[label]
185
+
186
+ def draw_plot(pred_img, seg):
187
+ fig = plt.figure(figsize=(20, 15))
188
+
189
+ grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
190
+
191
+ plt.subplot(grid_spec[0])
192
+ plt.imshow(pred_img)
193
+ plt.axis('off')
194
+
195
+ LABEL_NAMES = np.asarray(labels_list)
196
+ FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
197
+ FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
198
+
199
+ unique_labels = np.unique(seg.numpy().astype("uint8"))
200
+ ax = plt.subplot(grid_spec[1])
201
+ plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
202
+ ax.yaxis.tick_right()
203
+ plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
204
+ plt.xticks([], [])
205
+ ax.tick_params(width=0.0, labelsize=25)
206
+ return fig
207
+
208
+ def sepia(input_img):
209
+ input_img = Image.fromarray(input_img)
210
+ outputs = sess.run(None, {"pixel_values": input_img})
211
+
212
+ logits = outputs.logits
213
+
214
+ logits = tf.transpose(logits, [0, 2, 3, 1])
215
+ logits = tf.image.resize(
216
+ logits, input_img.size[::-1]
217
+ ) # We reverse the shape of `image` because `image.size` returns width and height.
218
+ seg = tf.math.argmax(logits, axis=-1)[0]
219
+
220
+ color_seg = np.zeros(
221
+ (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
222
+ ) # height, width, 3
223
+
224
+ for label, color in enumerate(colormap):
225
+ color_seg[seg == label, :] = color
226
+
227
+ # Convert to BGR
228
+ color_seg = color_seg[..., ::-1]
229
+
230
+ # Show image + mask
231
+ pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
232
+ pred_img = pred_img.astype(np.uint8)
233
+
234
+ fig = draw_plot(pred_img, seg)
235
+ return fig
236
+
237
+ demo = gr.Interface(sepia,
238
+ gr.Image(shape=(200, 200)),
239
+ outputs=['plot'],
240
+ # examples=["ADE_val_00000001.jpeg"],
241
+ allow_flagging='never')
242
+
243
+ demo.launch()