adpro committed
Commit a4a91be
1 Parent(s): ce4e094

Update app.py

Files changed (1)
  1. app.py +243 -29
app.py CHANGED
@@ -1,37 +1,251 @@
-from torch import torch.hub
-import gluoncv
-import mxnet as mx
-from gluoncv.utils.viz import get_color_pallete
-import gradio as gr
+from __future__ import absolute_import, division, print_function
+
+import click
+import cv2
+import matplotlib
+import matplotlib.cm as cm
+import matplotlib.pyplot as plt
 import numpy as np
-from PIL import Image
-from gluoncv.data.transforms.presets.segmentation import test_transform

-# using cpu
-ctx = mx.cpu(0)

-model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", pretrained='cocostuff164k', n_classes=182)


-def segmentation(image):
-    img = Image.fromarray(image)
-    img = mx.ndarray.array(img)
-    img = test_transform(img, ctx)
-    output = model.predict(img)
-    predict = mx.nd.squeeze(mx.nd.argmax(output, 1)).asnumpy()
-    mask = get_color_pallete(predict, "ade20k")
-    return mask


-image_in = gr.Image()
-image_out = gr.components.Image()
-description = "MXNet Image Segmentation Model"
-examples=['cat.jpeg']

-Iface = gr.Interface(
-    fn=segmentation,
-    inputs=image_in,
-    outputs=image_out,
-    title="Semantic Segmentation - MXNet",
-    examples=examples
-).launch()
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from omegaconf import OmegaConf
+
+from libs.models import *
+from libs.utils import DenseCRF
+
+
+def get_device(cuda):
+    cuda = cuda and torch.cuda.is_available()
+    device = torch.device("cuda" if cuda else "cpu")
+    if cuda:
+        current_device = torch.cuda.current_device()
+        print("Device:", torch.cuda.get_device_name(current_device))
+    else:
+        print("Device: CPU")
+    return device
+
+
+def get_classtable(CONFIG):
+    with open(CONFIG.DATASET.LABELS) as f:
+        classes = {}
+        for label in f:
+            label = label.rstrip().split("\t")
+            classes[int(label[0])] = label[1].split(",")[0]
+    return classes
+
+
+def setup_postprocessor(CONFIG):
+    # CRF post-processor
+    postprocessor = DenseCRF(
+        iter_max=CONFIG.CRF.ITER_MAX,
+        pos_xy_std=CONFIG.CRF.POS_XY_STD,
+        pos_w=CONFIG.CRF.POS_W,
+        bi_xy_std=CONFIG.CRF.BI_XY_STD,
+        bi_rgb_std=CONFIG.CRF.BI_RGB_STD,
+        bi_w=CONFIG.CRF.BI_W,
+    )
+    return postprocessor
+
+
+def preprocessing(image, device, CONFIG):
+    # Resize
+    scale = CONFIG.IMAGE.SIZE.TEST / max(image.shape[:2])
+    image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
+    raw_image = image.astype(np.uint8)
+
+    # Subtract mean values
+    image = image.astype(np.float32)
+    image -= np.array(
+        [
+            float(CONFIG.IMAGE.MEAN.B),
+            float(CONFIG.IMAGE.MEAN.G),
+            float(CONFIG.IMAGE.MEAN.R),
+        ]
+    )
+
+    # Convert to torch.Tensor and add "batch" axis
+    image = torch.from_numpy(image.transpose(2, 0, 1)).float().unsqueeze(0)
+    image = image.to(device)
+
+    return image, raw_image
+
+
+def inference(model, image, raw_image=None, postprocessor=None):
+    _, _, H, W = image.shape
+
+    # Image -> Probability map
+    logits = model(image)
+    logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
+    probs = F.softmax(logits, dim=1)[0]
+    probs = probs.cpu().numpy()
+
+    # Refine the prob map with CRF
+    if postprocessor and raw_image is not None:
+        probs = postprocessor(raw_image, probs)
+
+    labelmap = np.argmax(probs, axis=0)
+
+    return labelmap
+
+
+@click.group()
+@click.pass_context
+def main(ctx):
+    """
+    Demo with a trained model
+    """
+
+    print("Mode:", ctx.invoked_subcommand)
+
+
+@main.command()
+@click.option(
+    "-c",
+    "--config-path",
+    type=click.File(),
+    required=True,
+    help="Dataset configuration file in YAML",
+)
+@click.option(
+    "-m",
+    "--model-path",
+    type=click.Path(exists=True),
+    required=True,
+    help="PyTorch model to be loaded",
+)
+@click.option(
+    "-i",
+    "--image-path",
+    type=click.Path(exists=True),
+    required=True,
+    help="Image to be processed",
+)
+@click.option(
+    "--cuda/--cpu", default=True, help="Enable CUDA if available [default: --cuda]"
+)
+@click.option("--crf", is_flag=True, show_default=True, help="CRF post-processing")
+def single(config_path, model_path, image_path, cuda, crf):
+    """
+    Inference from a single image
+    """
+
+    # Setup
+    CONFIG = OmegaConf.load(config_path)
+    device = get_device(cuda)
+    torch.set_grad_enabled(False)
+
+    classes = get_classtable(CONFIG)
+    postprocessor = setup_postprocessor(CONFIG) if crf else None
+
+    model = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES)
+    state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
+    model.load_state_dict(state_dict)
+    model.eval()
+    model.to(device)
+    print("Model:", CONFIG.MODEL.NAME)
+
+    # Inference
+    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
+    image, raw_image = preprocessing(image, device, CONFIG)
+    labelmap = inference(model, image, raw_image, postprocessor)
+    labels = np.unique(labelmap)
+
+    # Show result for each class
+    rows = np.floor(np.sqrt(len(labels) + 1))
+    cols = np.ceil((len(labels) + 1) / rows)
+
+    plt.figure(figsize=(10, 10))
+    ax = plt.subplot(rows, cols, 1)
+    ax.set_title("Input image")
+    ax.imshow(raw_image[:, :, ::-1])
+    ax.axis("off")
+
+    for i, label in enumerate(labels):
+        mask = labelmap == label
+        ax = plt.subplot(rows, cols, i + 2)
+        ax.set_title(classes[label])
+        ax.imshow(raw_image[..., ::-1])
+        ax.imshow(mask.astype(np.float32), alpha=0.5)
+        ax.axis("off")
+
+    plt.tight_layout()
+    plt.show()
+
+
+@main.command()
+@click.option(
+    "-c",
+    "--config-path",
+    type=click.File(),
+    required=True,
+    help="Dataset configuration file in YAML",
+)
+@click.option(
+    "-m",
+    "--model-path",
+    type=click.Path(exists=True),
+    required=True,
+    help="PyTorch model to be loaded",
+)
+@click.option(
+    "--cuda/--cpu", default=True, help="Enable CUDA if available [default: --cuda]"
+)
+@click.option("--crf", is_flag=True, show_default=True, help="CRF post-processing")
+@click.option("--camera-id", type=int, default=0, show_default=True, help="Device ID")
+def live(config_path, model_path, cuda, crf, camera_id):
+    """
+    Inference from camera stream
+    """
+
+    # Setup
+    CONFIG = OmegaConf.load(config_path)
+    device = get_device(cuda)
+    torch.set_grad_enabled(False)
+    torch.backends.cudnn.benchmark = True
+
+    classes = get_classtable(CONFIG)
+    postprocessor = setup_postprocessor(CONFIG) if crf else None
+
+    model = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES)
+    state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
+    model.load_state_dict(state_dict)
+    model.eval()
+    model.to(device)
+    print("Model:", CONFIG.MODEL.NAME)
+
+    # UVC camera stream
+    cap = cv2.VideoCapture(camera_id)
+    cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"YUYV"))

+    def colorize(labelmap):
+        # Assign a unique color to each label
+        labelmap = labelmap.astype(np.float32) / CONFIG.DATASET.N_CLASSES
+        colormap = cm.jet_r(labelmap)[..., :-1] * 255.0
+        return np.uint8(colormap)

+    def mouse_event(event, x, y, flags, labelmap):
+        # Show a class name of a mouse-overed pixel
+        label = labelmap[y, x]
+        name = classes[label]
+        print(name)

+    window_name = "{} + {}".format(CONFIG.MODEL.NAME, CONFIG.DATASET.NAME)
+    cv2.namedWindow(window_name, cv2.WINDOW_AUTOSIZE)

+    while True:
+        _, frame = cap.read()
+        image, raw_image = preprocessing(frame, device, CONFIG)
+        labelmap = inference(model, image, raw_image, postprocessor)
+        colormap = colorize(labelmap)

+        # Register mouse callback function
+        cv2.setMouseCallback(window_name, mouse_event, labelmap)

+        # Overlay prediction
+        cv2.addWeighted(colormap, 0.5, raw_image, 0.5, 0.0, raw_image)

+        # Quit by pressing "q" key
+        cv2.imshow(window_name, raw_image)
+        if cv2.waitKey(10) == ord("q"):
+            break
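
For context: the rewritten app.py is a click CLI with two subcommands, single (one image, matplotlib figure per class) and live (camera stream with CRF-refined overlay), rather than a Gradio app. Assuming the file ends with the usual `if __name__ == "__main__": main()` guard (not visible in this hunk) and that a matching OmegaConf YAML and checkpoint exist, a hypothetical invocation looks like this (all paths below are placeholders, not files shipped with this commit):

    python app.py single --config-path configs/cocostuff164k.yaml --model-path checkpoint_final.pth --image-path cat.jpeg --crf
    python app.py live --config-path configs/cocostuff164k.yaml --model-path checkpoint_final.pth --cpu --camera-id 0

--crf switches on DenseCRF refinement of the softmax output, and --cuda/--cpu (default --cuda) selects the device via get_device().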
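The DenseCRF post-processor imported from libs.utils is configured in setup_postprocessor() but implemented outside this diff. As a rough sketch only, based on the parameter names (iter_max, pos_*, bi_*) rather than the repository's actual code, such a wrapper typically delegates to pydensecrf along these lines:

    import numpy as np
    import pydensecrf.densecrf as dcrf
    import pydensecrf.utils as crf_utils


    class DenseCRF:
        # Hypothetical wrapper; the real libs/utils.py is not shown in this diff.
        def __init__(self, iter_max, pos_w, pos_xy_std, bi_w, bi_xy_std, bi_rgb_std):
            self.iter_max = iter_max      # number of mean-field iterations
            self.pos_w = pos_w            # weight of the position-only (Gaussian) kernel
            self.pos_xy_std = pos_xy_std  # spatial std of the Gaussian kernel
            self.bi_w = bi_w              # weight of the bilateral (position + color) kernel
            self.bi_xy_std = bi_xy_std    # spatial std of the bilateral kernel
            self.bi_rgb_std = bi_rgb_std  # color std of the bilateral kernel

        def __call__(self, image, probmap):
            # image: HxWx3 uint8; probmap: CxHxW softmax probabilities,
            # matching how inference() calls postprocessor(raw_image, probs).
            C, H, W = probmap.shape
            unary = crf_utils.unary_from_softmax(probmap)  # -log(prob), shape (C, H*W)
            unary = np.ascontiguousarray(unary)
            image = np.ascontiguousarray(image)

            d = dcrf.DenseCRF2D(W, H, C)
            d.setUnaryEnergy(unary)
            d.addPairwiseGaussian(sxy=self.pos_xy_std, compat=self.pos_w)
            d.addPairwiseBilateral(
                sxy=self.bi_xy_std, srgb=self.bi_rgb_std, rgbim=image, compat=self.bi_w
            )

            Q = d.inference(self.iter_max)
            return np.array(Q).reshape((C, H, W))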