adpro commited on
Commit
36d8e96
·
1 Parent(s): a4a91be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -243
app.py CHANGED
@@ -1,251 +1,36 @@
1
- from __future__ import absolute_import, division, print_function
2
-
3
- import click
4
- import cv2
5
- import matplotlib
6
- import matplotlib.cm as cm
7
- import matplotlib.pyplot as plt
8
  import numpy as np
9
- import torch
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
- from omegaconf import OmegaConf
13
-
14
- from libs.models import *
15
- from libs.utils import DenseCRF
16
-
17
-
18
- def get_device(cuda):
19
- cuda = cuda and torch.cuda.is_available()
20
- device = torch.device("cuda" if cuda else "cpu")
21
- if cuda:
22
- current_device = torch.cuda.current_device()
23
- print("Device:", torch.cuda.get_device_name(current_device))
24
- else:
25
- print("Device: CPU")
26
- return device
27
-
28
-
29
- def get_classtable(CONFIG):
30
- with open(CONFIG.DATASET.LABELS) as f:
31
- classes = {}
32
- for label in f:
33
- label = label.rstrip().split("\t")
34
- classes[int(label[0])] = label[1].split(",")[0]
35
- return classes
36
-
37
-
38
- def setup_postprocessor(CONFIG):
39
- # CRF post-processor
40
- postprocessor = DenseCRF(
41
- iter_max=CONFIG.CRF.ITER_MAX,
42
- pos_xy_std=CONFIG.CRF.POS_XY_STD,
43
- pos_w=CONFIG.CRF.POS_W,
44
- bi_xy_std=CONFIG.CRF.BI_XY_STD,
45
- bi_rgb_std=CONFIG.CRF.BI_RGB_STD,
46
- bi_w=CONFIG.CRF.BI_W,
47
- )
48
- return postprocessor
49
-
50
-
51
- def preprocessing(image, device, CONFIG):
52
- # Resize
53
- scale = CONFIG.IMAGE.SIZE.TEST / max(image.shape[:2])
54
- image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
55
- raw_image = image.astype(np.uint8)
56
-
57
- # Subtract mean values
58
- image = image.astype(np.float32)
59
- image -= np.array(
60
- [
61
- float(CONFIG.IMAGE.MEAN.B),
62
- float(CONFIG.IMAGE.MEAN.G),
63
- float(CONFIG.IMAGE.MEAN.R),
64
- ]
65
- )
66
-
67
- # Convert to torch.Tensor and add "batch" axis
68
- image = torch.from_numpy(image.transpose(2, 0, 1)).float().unsqueeze(0)
69
- image = image.to(device)
70
-
71
- return image, raw_image
72
-
73
-
74
- def inference(model, image, raw_image=None, postprocessor=None):
75
- _, _, H, W = image.shape
76
-
77
- # Image -> Probability map
78
- logits = model(image)
79
- logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
80
- probs = F.softmax(logits, dim=1)[0]
81
- probs = probs.cpu().numpy()
82
-
83
- # Refine the prob map with CRF
84
- if postprocessor and raw_image is not None:
85
- probs = postprocessor(raw_image, probs)
86
-
87
- labelmap = np.argmax(probs, axis=0)
88
-
89
- return labelmap
90
-
91
-
92
- @click.group()
93
- @click.pass_context
94
- def main(ctx):
95
- """
96
- Demo with a trained model
97
- """
98
-
99
- print("Mode:", ctx.invoked_subcommand)
100
-
101
-
102
- @main.command()
103
- @click.option(
104
- "-c",
105
- "--config-path",
106
- type=click.File(),
107
- required=True,
108
- help="Dataset configuration file in YAML",
109
- )
110
- @click.option(
111
- "-m",
112
- "--model-path",
113
- type=click.Path(exists=True),
114
- required=True,
115
- help="PyTorch model to be loaded",
116
- )
117
- @click.option(
118
- "-i",
119
- "--image-path",
120
- type=click.Path(exists=True),
121
- required=True,
122
- help="Image to be processed",
123
- )
124
- @click.option(
125
- "--cuda/--cpu", default=True, help="Enable CUDA if available [default: --cuda]"
126
- )
127
- @click.option("--crf", is_flag=True, show_default=True, help="CRF post-processing")
128
- def single(config_path, model_path, image_path, cuda, crf):
129
- """
130
- Inference from a single image
131
- """
132
-
133
- # Setup
134
- CONFIG = OmegaConf.load(config_path)
135
- device = get_device(cuda)
136
- torch.set_grad_enabled(False)
137
-
138
- classes = get_classtable(CONFIG)
139
- postprocessor = setup_postprocessor(CONFIG) if crf else None
140
-
141
- model = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES)
142
- state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
143
- model.load_state_dict(state_dict)
144
- model.eval()
145
- model.to(device)
146
- print("Model:", CONFIG.MODEL.NAME)
147
-
148
- # Inference
149
- image = cv2.imread(image_path, cv2.IMREAD_COLOR)
150
- image, raw_image = preprocessing(image, device, CONFIG)
151
- labelmap = inference(model, image, raw_image, postprocessor)
152
- labels = np.unique(labelmap)
153
-
154
- # Show result for each class
155
- rows = np.floor(np.sqrt(len(labels) + 1))
156
- cols = np.ceil((len(labels) + 1) / rows)
157
-
158
- plt.figure(figsize=(10, 10))
159
- ax = plt.subplot(rows, cols, 1)
160
- ax.set_title("Input image")
161
- ax.imshow(raw_image[:, :, ::-1])
162
- ax.axis("off")
163
-
164
- for i, label in enumerate(labels):
165
- mask = labelmap == label
166
- ax = plt.subplot(rows, cols, i + 2)
167
- ax.set_title(classes[label])
168
- ax.imshow(raw_image[..., ::-1])
169
- ax.imshow(mask.astype(np.float32), alpha=0.5)
170
- ax.axis("off")
171
-
172
- plt.tight_layout()
173
- plt.show()
174
-
175
-
176
- @main.command()
177
- @click.option(
178
- "-c",
179
- "--config-path",
180
- type=click.File(),
181
- required=True,
182
- help="Dataset configuration file in YAML",
183
- )
184
- @click.option(
185
- "-m",
186
- "--model-path",
187
- type=click.Path(exists=True),
188
- required=True,
189
- help="PyTorch model to be loaded",
190
- )
191
- @click.option(
192
- "--cuda/--cpu", default=True, help="Enable CUDA if available [default: --cuda]"
193
- )
194
- @click.option("--crf", is_flag=True, show_default=True, help="CRF post-processing")
195
- @click.option("--camera-id", type=int, default=0, show_default=True, help="Device ID")
196
- def live(config_path, model_path, cuda, crf, camera_id):
197
- """
198
- Inference from camera stream
199
- """
200
-
201
- # Setup
202
- CONFIG = OmegaConf.load(config_path)
203
- device = get_device(cuda)
204
- torch.set_grad_enabled(False)
205
- torch.backends.cudnn.benchmark = True
206
-
207
- classes = get_classtable(CONFIG)
208
- postprocessor = setup_postprocessor(CONFIG) if crf else None
209
-
210
- model = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES)
211
- state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
212
- model.load_state_dict(state_dict)
213
- model.eval()
214
- model.to(device)
215
- print("Model:", CONFIG.MODEL.NAME)
216
-
217
- # UVC camera stream
218
- cap = cv2.VideoCapture(camera_id)
219
- cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"YUYV"))
220
 
221
- def colorize(labelmap):
222
- # Assign a unique color to each label
223
- labelmap = labelmap.astype(np.float32) / CONFIG.DATASET.N_CLASSES
224
- colormap = cm.jet_r(labelmap)[..., :-1] * 255.0
225
- return np.uint8(colormap)
226
 
227
- def mouse_event(event, x, y, flags, labelmap):
228
- # Show a class name of a mouse-overed pixel
229
- label = labelmap[y, x]
230
- name = classes[label]
231
- print(name)
232
 
233
- window_name = "{} + {}".format(CONFIG.MODEL.NAME, CONFIG.DATASET.NAME)
234
- cv2.namedWindow(window_name, cv2.WINDOW_AUTOSIZE)
235
 
236
- while True:
237
- _, frame = cap.read()
238
- image, raw_image = preprocessing(frame, device, CONFIG)
239
- labelmap = inference(model, image, raw_image, postprocessor)
240
- colormap = colorize(labelmap)
 
 
 
241
 
242
- # Register mouse callback function
243
- cv2.setMouseCallback(window_name, mouse_event, labelmap)
244
 
245
- # Overlay prediction
246
- cv2.addWeighted(colormap, 0.5, raw_image, 0.5, 0.0, raw_image)
 
 
247
 
248
- # Quit by pressing "q" key
249
- cv2.imshow(window_name, raw_image)
250
- if cv2.waitKey(10) == ord("q"):
251
- break
 
 
 
 
1
+ import gluoncv
2
+ import mxnet as mx
3
+ from gluoncv.utils.viz import get_color_pallete
4
+ import gradio as gr
 
 
 
5
  import numpy as np
6
+ from PIL import Image
7
+ from gluoncv.data.transforms.presets.segmentation import test_transform
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ # using cpu
10
+ ctx = mx.cpu(0)
 
 
 
11
 
12
+ model = gluoncv.model_zoo.get_model("deeplab_resnet101_ade", pretrained=True)
 
 
 
 
13
 
 
 
14
 
15
+ def segmentation(image):
16
+ img = Image.fromarray(image)
17
+ img = mx.ndarray.array(img)
18
+ img = test_transform(img, ctx)
19
+ output = model.predict(img)
20
+ predict = mx.nd.squeeze(mx.nd.argmax(output, 1)).asnumpy()
21
+ mask = get_color_pallete(predict, "ade20k")
22
+ return mask
23
 
 
 
24
 
25
+ image_in = gr.Image()
26
+ image_out = gr.components.Image()
27
+ description = "MXNet Image Segmentation Model"
28
+ examples=['cat.jpeg']
29
 
30
+ Iface = gr.Interface(
31
+ fn=segmentation,
32
+ inputs=image_in,
33
+ outputs=image_out,
34
+ title="Semantic Segmentation - MXNet",
35
+ examples=examples
36
+ ).launch()