ONNX
Aser Atawya committed on
Commit f5570ff · 0 Parent(s):

Google Summer of Code: Adding RAFT Optical Flow Model using ONNX Format (#197)


* RAFT ONNX GSoC

* use only opencv instead of onnx

* correct typo in help display message of demo.py

* add video functionality

* Add some clarity to README.md

Files changed (5)
  1. BSD-3-LICENSE.txt +29 -0
  2. MITLICENSE.txt +21 -0
  3. README.md +68 -0
  4. demo.py +310 -0
  5. raft.py +53 -0
BSD-3-LICENSE.txt ADDED
BSD 3-Clause License

Copyright (c) 2020, princeton-vl
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
MITLICENSE.txt ADDED
MIT License

Copyright (c) 2021 Jeong-gi Kwak

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md ADDED
# RAFT
This model was originally created by Zachary Teed and Jia Deng of Princeton University. The source code is available at [their repository on GitHub](https://github.com/princeton-vl/RAFT), and the original [research paper](https://arxiv.org/abs/2003.12039) is published on [arXiv](https://arxiv.org/abs/2003.12039). The model was converted to ONNX by [PINTO0309](https://github.com/PINTO0309) in his [model zoo](https://github.com/PINTO0309/PINTO_model_zoo/tree/main/252_RAFT). The ONNX model has several variations depending on the training dataset and input dimensions. The model used in this demo was trained on the Sintel dataset with an input size of 360 $\times$ 480.


## Demo

Run any of the following commands to try the demo:

```shell
# run on camera input
python demo.py

# run on two images and visualize result
python demo.py --input1 /path/to/image1 --input2 /path/to/image2 -vis

# run on two images and save result
python demo.py --input1 /path/to/image1 --input2 /path/to/image2 -s

# run on two images and both save and visualize result
python demo.py --input1 /path/to/image1 --input2 /path/to/image2 -s -vis

# run on one video and visualize result
python demo.py --video /path/to/video -vis

# run on one video and save result
python demo.py --video /path/to/video -s

# run on one video and both save and visualize result
python demo.py --video /path/to/video -s -vis

# get help regarding various parameters
python demo.py --help
```
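
If you prefer to call the model from your own Python code, the following is a minimal sketch of driving the `Raft` wrapper in raft.py directly. The image file names are placeholders, and the model path matches the default of demo.py's `--model` argument:

```python
import cv2 as cv

from raft import Raft

# Placeholder paths; substitute your own image pair and downloaded ONNX model.
model = Raft(modelPath='optical_flow_estimation_raft_2023aug.onnx')
image1 = cv.imread('frame1.jpg')
image2 = cv.imread('frame2.jpg')

# infer() resizes both images to the model's 360x480 input and returns the
# estimated flow map (horizontal and vertical components in the last dimension).
flow_map = model.infer(image1, image2)
print(flow_map.shape)
```

demo.py's `draw_flow()` can then turn this flow map into a color-coded image at the original image resolution.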

The demo runs on camera input, on a video, or on two images, and computes the optical flow across frames. While running on a video, you can press q at any time to stop. The save (-s) and visualize (-vis) arguments are only valid when the input is a video or two images. To run a different variation of the model, such as one trained on a different dataset or with a different input size, download your chosen model from [RAFT ONNX in PINTO Model Zoo](https://github.com/PINTO0309/PINTO_model_zoo/tree/main/252_RAFT). If your chosen model has an input shape other than 360 $\times$ 480, **change the input shape in raft.py line 15 to the new input shape** (see the snippet after the commands below). Then pass the model path through the --model argument, as in the following example commands:

```shell
# run on camera input
python demo.py --model /path/to/model
# run on two images
python demo.py --input1 /path/to/image1 --input2 /path/to/image2 --model /path/to/model
# run on video
python demo.py --video /path/to/video --model /path/to/model
```
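
For example, if the model you download hypothetically expects 480 $\times$ 640 inputs, line 15 of raft.py would become:

```python
self.input_shape = [480, 640]  # hypothetical [height, width] of the chosen ONNX model
```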

### Example outputs
The visualization argument displays both input images as well as the result.

![Visualization example](./example_outputs/vis.png)

The save argument saves the result only.

![Output example](./example_outputs/result.jpg)


## License

The original RAFT model is under the [BSD-3-Clause license](./BSD-3-LICENSE.txt). <br />
The conversion of the RAFT model to the ONNX format by [PINTO0309](https://github.com/PINTO0309/PINTO_model_zoo/tree/main/252_RAFT) is under the [MIT License](./MITLICENSE.txt). <br />
Some of the code in demo.py and raft.py is adapted from [ibaiGorordo's repository](https://github.com/ibaiGorordo/ONNX-RAFT-Optical-Flow-Estimation/tree/main), which is under the [BSD-3-Clause license](./BSD-3-LICENSE.txt). <br />

## Reference

- https://arxiv.org/abs/2003.12039
- https://github.com/princeton-vl/RAFT
- https://github.com/ibaiGorordo/ONNX-RAFT-Optical-Flow-Estimation/tree/main
- https://github.com/PINTO0309/PINTO_model_zoo/tree/main/252_RAFT
demo.py ADDED
import argparse

import cv2 as cv
import numpy as np

from raft import Raft

parser = argparse.ArgumentParser(description='RAFT (https://github.com/princeton-vl/RAFT)')
parser.add_argument('--input1', '-i1', type=str,
                    help='Usage: Set input1 path to first image, omit if using camera or video.')
parser.add_argument('--input2', '-i2', type=str,
                    help='Usage: Set input2 path to second image, omit if using camera or video.')
parser.add_argument('--video', '-vid', type=str,
                    help='Usage: Set video path to desired input video, omit if using camera or two image inputs.')
parser.add_argument('--model', '-m', type=str, default='optical_flow_estimation_raft_2023aug.onnx',
                    help='Usage: Set model path, defaults to optical_flow_estimation_raft_2023aug.onnx.')
parser.add_argument('--save', '-s', action='store_true',
                    help='Usage: Specify to save a file with results. Invalid in case of camera input.')
parser.add_argument('--visual', '-vis', action='store_true',
                    help='Usage: Specify to open a new window to show results. Invalid in case of camera input.')
args = parser.parse_args()

UNKNOWN_FLOW_THRESH = 1e7

def make_color_wheel():
    """Generate the color wheel according to the Middlebury color code.

    Returns:
        colorwheel (numpy.ndarray): Color wheel
    """
    RY = 15
    YG = 6
    GC = 4
    CB = 11
    BM = 13
    MR = 6

    ncols = RY + YG + GC + CB + BM + MR

    colorwheel = np.zeros([ncols, 3])

    col = 0

    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY))
    col += RY

    # YG
    colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG))
    colorwheel[col:col+YG, 1] = 255
    col += YG

    # GC
    colorwheel[col:col+GC, 1] = 255
    colorwheel[col:col+GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC))
    col += GC

    # CB
    colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB))
    colorwheel[col:col+CB, 2] = 255
    col += CB

    # BM
    colorwheel[col:col+BM, 2] = 255
    colorwheel[col:col+BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM))
    col += BM

    # MR
    colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
    colorwheel[col:col+MR, 0] = 255

    return colorwheel

colorwheel = make_color_wheel()

def compute_color(u, v):
    """Compute the optical flow color map.

    Args:
        u (numpy.ndarray): Optical flow horizontal map
        v (numpy.ndarray): Optical flow vertical map

    Returns:
        img (numpy.ndarray): Optical flow in color code
    """
    [h, w] = u.shape
    img = np.zeros([h, w, 3])
    nanIdx = np.isnan(u) | np.isnan(v)
    u[nanIdx] = 0
    v[nanIdx] = 0

    ncols = np.size(colorwheel, 0)

    rad = np.sqrt(u**2 + v**2)

    a = np.arctan2(-v, -u) / np.pi

    fk = (a + 1) / 2 * (ncols - 1) + 1

    k0 = np.floor(fk).astype(int)

    k1 = k0 + 1
    k1[k1 == ncols + 1] = 1
    f = fk - k0

    for i in range(0, np.size(colorwheel, 1)):
        tmp = colorwheel[:, i]
        col0 = tmp[k0 - 1] / 255
        col1 = tmp[k1 - 1] / 255
        col = (1 - f) * col0 + f * col1

        idx = rad <= 1
        col[idx] = 1 - rad[idx] * (1 - col[idx])
        notidx = np.logical_not(idx)

        col[notidx] *= 0.75
        img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))

    return img

def flow_to_image(flow):
    """Convert a flow map into a Middlebury color code image.

    Args:
        flow (np.ndarray): The computed flow map

    Returns:
        (np.ndarray): Image corresponding to the flow map.
    """
    u = flow[:, :, 0]
    v = flow[:, :, 1]

    maxu = -999.
    maxv = -999.
    minu = 999.
    minv = 999.

    idxUnknown = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
    u[idxUnknown] = 0
    v[idxUnknown] = 0

    maxu = max(maxu, np.max(u))
    minu = min(minu, np.min(u))

    maxv = max(maxv, np.max(v))
    minv = min(minv, np.min(v))

    rad = np.sqrt(u ** 2 + v ** 2)
    maxrad = max(-1, np.max(rad))

    u = u / (maxrad + np.finfo(float).eps)
    v = v / (maxrad + np.finfo(float).eps)

    img = compute_color(u, v)

    idx = np.repeat(idxUnknown[:, :, np.newaxis], 3, axis=2)
    img[idx] = 0

    return np.uint8(img)


def draw_flow(flow_map, img_width, img_height):
    """Convert a flow map to an image.

    Args:
        flow_map (np.ndarray): The computed flow map
        img_width (int): The width of the first input photo
        img_height (int): The height of the first input photo

    Returns:
        (np.ndarray): Image corresponding to the flow map.
    """
    # Convert flow to image
    flow_img = flow_to_image(flow_map)
    # Convert to BGR
    flow_img = cv.cvtColor(flow_img, cv.COLOR_RGB2BGR)
    # Resize the flow image to match the input image shape
    return cv.resize(flow_img, (img_width, img_height))


def visualize(image1, image2, flow_img):
    """Combine the two input images with the resulting flow image and display them together.

    Args:
        image1 (np.ndarray): The first input image.
        image2 (np.ndarray): The second input image.
        flow_img (np.ndarray): The output flow map drawn as an image

    Returns:
        combined_img (np.ndarray): The visualized result.
    """
    combined_img = np.hstack((image1, image2, flow_img))
    cv.namedWindow("Estimated flow", cv.WINDOW_NORMAL)
    cv.imshow("Estimated flow", combined_img)
    cv.waitKey(0)
    return combined_img


if __name__ == '__main__':
    # Instantiate RAFT
    model = Raft(modelPath=args.model)

    if args.input1 is not None and args.input2 is not None:
        # Read the two images
        image1 = cv.imread(args.input1)
        image2 = cv.imread(args.input2)
        img_height, img_width, img_channels = image1.shape

        # Inference
        result = model.infer(image1, image2)

        # Create the flow image based on the resulting flow map
        flow_image = draw_flow(result, img_width, img_height)

        # Save results if save is true
        if args.save:
            print('Results saved to result.jpg\n')
            cv.imwrite('result.jpg', flow_image)

        # Visualize results in a new window
        if args.visual:
            input_output_visualization = visualize(image1, image2, flow_image)

    elif args.video is not None:
        cap = cv.VideoCapture(args.video)
        FLOW_FRAME_OFFSET = 3  # Number of frames between the two frames used to estimate the optical flow

        if args.visual:
            cv.namedWindow("Estimated flow", cv.WINDOW_NORMAL)

        frame_list = []
        img_array = []
        frame_num = 0
        while cap.isOpened():
            try:
                # Read frame from the video
                ret, prev_frame = cap.read()
                if not ret:
                    break
                frame_list.append(prev_frame)
            except:
                continue

            frame_num += 1
            if frame_num <= FLOW_FRAME_OFFSET:
                continue
            else:
                frame_num = 0

            # Estimate the flow between the oldest and the newest buffered frames
            result = model.infer(frame_list[0], frame_list[-1])
            img_height, img_width, img_channels = frame_list[0].shape
            flow_img = draw_flow(result, img_width, img_height)

            # Blend the flow image over the original frame
            alpha = 0.6
            combined_img = cv.addWeighted(frame_list[0], alpha, flow_img, (1 - alpha), 0)

            if args.visual:
                cv.imshow("Estimated flow", combined_img)
            img_array.append(combined_img)
            # Remove the oldest frame
            frame_list.pop(0)

            # Press key q to stop
            if cv.waitKey(1) == ord('q'):
                break

        cap.release()

        if args.save:
            fourcc = cv.VideoWriter_fourcc(*'mp4v')
            height, width, layers = img_array[0].shape
            video = cv.VideoWriter('result.mp4', fourcc, 30.0, (width, height), isColor=True)
            for img in img_array:
                video.write(img)
            video.release()

        cv.destroyAllWindows()

    else:  # Omit input to call the default camera
        deviceId = 0
        cap = cv.VideoCapture(deviceId)
        w = int(cap.get(cv.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT))

        tm = cv.TickMeter()
        while cv.waitKey(30) < 0:
            hasFrame1, frame1 = cap.read()
            hasFrame2, frame2 = cap.read()
            if not hasFrame1:
                print('First frame was not grabbed!')
                break

            if not hasFrame2:
                print('Second frame was not grabbed!')
                break

            # Inference
            tm.start()
            result = model.infer(frame1, frame2)
            tm.stop()
            result = draw_flow(result, w, h)

            # Draw results on the input image
            frame = visualize(frame1, frame2, result)

            tm.reset()
raft.py ADDED
# This file is part of OpenCV Zoo project.

import cv2 as cv
import numpy as np


class Raft:
    def __init__(self, modelPath):
        self._modelPath = modelPath
        self.model = cv.dnn.readNet(self._modelPath)

        self.input_names = ['0', '1']
        self.first_input_name = self.input_names[0]
        self.second_input_name = self.input_names[1]
        self.input_shape = [360, 480]  # [height, width]; change if using a different model with a different input shape
        self.input_height = self.input_shape[0]
        self.input_width = self.input_shape[1]

    @property
    def name(self):
        return self.__class__.__name__

    def _preprocess(self, image):
        # Convert BGR to RGB, resize to the model input shape, and lay out as NCHW float32
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
        img_input = cv.resize(image, (self.input_width, self.input_height))
        img_input = img_input.transpose(2, 0, 1)
        img_input = img_input[np.newaxis, :, :, :]
        img_input = img_input.astype(np.float32)
        return img_input

    def infer(self, image1, image2):
        # Preprocess
        input_1 = self._preprocess(image1)
        input_2 = self._preprocess(image2)

        # Forward
        self.model.setInput(input_1, self.first_input_name)
        self.model.setInput(input_2, self.second_input_name)
        layer_names = self.model.getLayerNames()
        output_layers = [layer_names[i - 1] for i in self.model.getUnconnectedOutLayers()]
        output = self.model.forward(output_layers)

        # Postprocess
        results = self._postprocess(output)

        return results

    def _postprocess(self, output):
        # Take the flow field from the second network output and reorder it to (height, width, 2)
        flow_map = output[1][0].transpose(1, 2, 0)
        return flow_map