Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,37 +1,251 @@
|
|
1 |
-
from
|
2 |
-
|
3 |
-
import
|
4 |
-
|
5 |
-
import
|
|
|
|
|
6 |
import numpy as np
|
7 |
-
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
14 |
|
|
|
|
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
predict = mx.nd.squeeze(mx.nd.argmax(output, 1)).asnumpy()
|
22 |
-
mask = get_color_pallete(predict, "ade20k")
|
23 |
-
return mask
|
24 |
|
|
|
|
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
description = "MXNet Image Segmentation Model"
|
29 |
-
examples=['cat.jpeg']
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
title="Semantic Segmentation - MXNet",
|
36 |
-
examples=examples
|
37 |
-
).launch()
|
|
|
1 |
+
from __future__ import absolute_import, division, print_function
|
2 |
+
|
3 |
+
import click
|
4 |
+
import cv2
|
5 |
+
import matplotlib
|
6 |
+
import matplotlib.cm as cm
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from omegaconf import OmegaConf
|
13 |
+
|
14 |
+
from libs.models import *
|
15 |
+
from libs.utils import DenseCRF
|
16 |
+
|
17 |
+
|
18 |
+
def get_device(cuda):
    """
    Pick the torch device to run on and print the choice.

    CUDA is selected only when ``cuda`` is truthy and
    ``torch.cuda.is_available()`` reports a usable GPU; otherwise the CPU
    device is returned.
    """
    use_cuda = cuda and torch.cuda.is_available()

    # CPU fallback: either CUDA was not requested or it is unavailable
    if not use_cuda:
        print("Device: CPU")
        return torch.device("cpu")

    gpu_index = torch.cuda.current_device()
    print("Device:", torch.cuda.get_device_name(gpu_index))
    return torch.device("cuda")
|
27 |
+
|
28 |
+
|
29 |
+
def get_classtable(CONFIG):
    """
    Read the label table from the file at ``CONFIG.DATASET.LABELS``.

    Each non-blank line holds an integer label ID, a tab, and a
    comma-separated list of names; only the first name is kept.
    Blank or whitespace-only lines (e.g. a trailing empty line) are skipped
    instead of raising ``ValueError`` on ``int("")``.

    Returns a dict mapping ``int`` label ID to class name.
    """
    classes = {}
    with open(CONFIG.DATASET.LABELS) as f:
        for line in f:
            line = line.rstrip()
            if not line:
                # Nothing to parse on an empty line
                continue
            fields = line.split("\t")
            classes[int(fields[0])] = fields[1].split(",")[0]
    return classes
|
36 |
+
|
37 |
+
|
38 |
+
def setup_postprocessor(CONFIG):
    """
    Build the DenseCRF post-processor from the values under ``CONFIG.CRF``.
    """
    crf = CONFIG.CRF
    return DenseCRF(
        iter_max=crf.ITER_MAX,
        pos_xy_std=crf.POS_XY_STD,
        pos_w=crf.POS_W,
        bi_xy_std=crf.BI_XY_STD,
        bi_rgb_std=crf.BI_RGB_STD,
        bi_w=crf.BI_W,
    )
|
49 |
+
|
50 |
+
|
51 |
+
def preprocessing(image, device, CONFIG):
    """
    Turn an image array into a model input tensor.

    The image is rescaled so that its longer side equals
    ``CONFIG.IMAGE.SIZE.TEST``, the per-channel means from
    ``CONFIG.IMAGE.MEAN`` (in B, G, R order) are subtracted, and the result
    is converted to a float tensor with a leading batch axis on ``device``.

    Returns ``(tensor, raw_image)`` where ``raw_image`` is the resized image
    as ``uint8``.
    """
    # Uniform scale factor derived from the longer of the first two axes
    longest_side = max(image.shape[:2])
    factor = CONFIG.IMAGE.SIZE.TEST / longest_side
    resized = cv2.resize(image, dsize=None, fx=factor, fy=factor)
    raw_image = resized.astype(np.uint8)

    # Mean subtraction is done in place on a float32 copy
    centered = resized.astype(np.float32)
    channel_means = np.array(
        [
            float(CONFIG.IMAGE.MEAN.B),
            float(CONFIG.IMAGE.MEAN.G),
            float(CONFIG.IMAGE.MEAN.R),
        ]
    )
    centered -= channel_means

    # Channels first, then prepend the batch axis
    tensor = torch.from_numpy(centered.transpose(2, 0, 1)).float().unsqueeze(0)
    return tensor.to(device), raw_image
|
72 |
+
|
73 |
+
|
74 |
+
def inference(model, image, raw_image=None, postprocessor=None):
    """
    Run the model on one batched image tensor and return its label map.

    The model output is bilinearly resized to the input's spatial size,
    softmaxed over dim 1, and the first batch item is moved to NumPy.
    When ``postprocessor`` is truthy and ``raw_image`` is given, the
    probabilities are refined by ``postprocessor(raw_image, probs)``.
    The result is the argmax over axis 0 of the probability array.
    """
    _batch, _channels, height, width = image.shape

    # Image -> per-class probabilities at input resolution
    scores = F.interpolate(
        model(image), size=(height, width), mode="bilinear", align_corners=False
    )
    probs = F.softmax(scores, dim=1)[0].cpu().numpy()

    # Optional refinement (e.g. CRF) needs the raw image as well
    refine = postprocessor and raw_image is not None
    if refine:
        probs = postprocessor(raw_image, probs)

    return np.argmax(probs, axis=0)
|
90 |
+
|
91 |
+
|
92 |
+
@click.group()
@click.pass_context
def main(ctx):
    """
    Demo with a trained model
    """

    # Report which subcommand the group is about to dispatch to
    subcommand = ctx.invoked_subcommand
    print("Mode:", subcommand)
|
100 |
+
|
101 |
+
|
102 |
+
@main.command()
@click.option(
    "-c",
    "--config-path",
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "-m",
    "--model-path",
    type=click.Path(exists=True),
    required=True,
    help="PyTorch model to be loaded",
)
@click.option(
    "-i",
    "--image-path",
    type=click.Path(exists=True),
    required=True,
    help="Image to be processed",
)
@click.option(
    "--cuda/--cpu", default=True, help="Enable CUDA if available [default: --cuda]"
)
@click.option("--crf", is_flag=True, show_default=True, help="CRF post-processing")
def single(config_path, model_path, image_path, cuda, crf):
    """
    Inference from a single image
    """
    # The docstring above is left short on purpose: click prints it as the
    # command's help text.
    #
    # Flow: load the config and model, segment the image at image_path, then
    # show one matplotlib panel for the input plus one per predicted class.
    # Raises click.ClickException when the image cannot be decoded.

    # Setup
    CONFIG = OmegaConf.load(config_path)
    device = get_device(cuda)
    torch.set_grad_enabled(False)

    classes = get_classtable(CONFIG)
    postprocessor = setup_postprocessor(CONFIG) if crf else None

    # NOTE(review): eval() runs whatever expression the config file contains;
    # only use trusted configs, or replace with an explicit name -> class map.
    model = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES)
    # Load weights onto the CPU first; the model is moved to `device` below
    state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    print("Model:", CONFIG.MODEL.NAME)

    # Inference
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread signals an unreadable/unsupported file by returning None
        raise click.ClickException("Failed to read image: {}".format(image_path))
    image, raw_image = preprocessing(image, device, CONFIG)
    labelmap = inference(model, image, raw_image, postprocessor)
    labels = np.unique(labelmap)

    # Show result for each class
    # One extra panel for the input image; plt.subplot needs integer counts
    n_panels = len(labels) + 1
    rows = int(np.floor(np.sqrt(n_panels)))
    cols = int(np.ceil(n_panels / rows))

    plt.figure(figsize=(10, 10))
    ax = plt.subplot(rows, cols, 1)
    ax.set_title("Input image")
    # Reverse the channel axis: cv2 loads BGR, matplotlib expects RGB
    ax.imshow(raw_image[:, :, ::-1])
    ax.axis("off")

    for i, label in enumerate(labels):
        mask = labelmap == label
        ax = plt.subplot(rows, cols, i + 2)
        ax.set_title(classes[label])
        ax.imshow(raw_image[..., ::-1])
        ax.imshow(mask.astype(np.float32), alpha=0.5)
        ax.axis("off")

    plt.tight_layout()
    plt.show()
|
174 |
+
|
175 |
+
|
176 |
+
@main.command()
@click.option(
    "-c",
    "--config-path",
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "-m",
    "--model-path",
    type=click.Path(exists=True),
    required=True,
    help="PyTorch model to be loaded",
)
@click.option(
    "--cuda/--cpu", default=True, help="Enable CUDA if available [default: --cuda]"
)
@click.option("--crf", is_flag=True, show_default=True, help="CRF post-processing")
@click.option("--camera-id", type=int, default=0, show_default=True, help="Device ID")
def live(config_path, model_path, cuda, crf, camera_id):
    """
    Inference from camera stream
    """
    # The docstring above is left short on purpose: click prints it as the
    # command's help text.
    #
    # Flow: load the config and model, then repeatedly grab a frame from the
    # camera, segment it and display the blended overlay until "q" is pressed
    # or a frame cannot be read. The camera and window are always released.
    # Raises click.ClickException when the camera cannot be opened.

    # Setup
    CONFIG = OmegaConf.load(config_path)
    device = get_device(cuda)
    torch.set_grad_enabled(False)
    torch.backends.cudnn.benchmark = True

    classes = get_classtable(CONFIG)
    postprocessor = setup_postprocessor(CONFIG) if crf else None

    # NOTE(review): eval() runs whatever expression the config file contains;
    # only use trusted configs, or replace with an explicit name -> class map.
    model = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES)
    # Load weights onto the CPU first; the model is moved to `device` below
    state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    print("Model:", CONFIG.MODEL.NAME)

    # UVC camera stream
    cap = cv2.VideoCapture(camera_id)
    if not cap.isOpened():
        cap.release()
        raise click.ClickException("Failed to open camera: {}".format(camera_id))
    cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"YUYV"))

    def colorize(labelmap):
        # Assign a unique color to each label
        labelmap = labelmap.astype(np.float32) / CONFIG.DATASET.N_CLASSES
        # Drop the last (alpha) channel of the colormap output, scale to 0-255
        colormap = cm.jet_r(labelmap)[..., :-1] * 255.0
        return np.uint8(colormap)

    def mouse_event(event, x, y, flags, labelmap):
        # Show a class name of a mouse-overed pixel
        height, width = labelmap.shape[:2]
        if not (0 <= y < height and 0 <= x < width):
            # Ignore coordinates outside the label map so the callback
            # cannot index out of bounds or wrap around on negative values
            return
        label = labelmap[y, x]
        name = classes[label]
        print(name)

    window_name = "{} + {}".format(CONFIG.MODEL.NAME, CONFIG.DATASET.NAME)

    try:
        cv2.namedWindow(window_name, cv2.WINDOW_AUTOSIZE)

        while True:
            ok, frame = cap.read()
            if not ok:
                # No frame was delivered (stream ended or device failure)
                break
            image, raw_image = preprocessing(frame, device, CONFIG)
            labelmap = inference(model, image, raw_image, postprocessor)
            colormap = colorize(labelmap)

            # Register mouse callback function
            # Re-registered every frame so the callback sees the latest labelmap
            cv2.setMouseCallback(window_name, mouse_event, labelmap)

            # Overlay prediction
            # The blend is written back into raw_image (last argument is dst)
            cv2.addWeighted(colormap, 0.5, raw_image, 0.5, 0.0, raw_image)

            # Quit by pressing "q" key
            cv2.imshow(window_name, raw_image)
            # Mask to the low byte so extra high bits in the key code are ignored
            if cv2.waitKey(10) & 0xFF == ord("q"):
                break
    finally:
        # Always free the capture device and close the display window
        cap.release()
        cv2.destroyAllWindows()
|
|
|
|
|
|