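"""MixFormer2 single-object tracking demo on an Axera axmodel.

Loads a compiled tracking model through axengine, initializes the tracker from
the first frame of a video file, RTSP stream, or image folder, then tracks the
target frame by frame and writes annotated frames to `axmodel_output/`.
"""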
import argparse
import os
import re
import sys
import time
import cv2
import math
import glob
import numpy as np

import axengine as axe
from axengine import axclrt_provider_name, axengine_provider_name

def load_model(model_path: str | os.PathLike, selected_provider: str, selected_device_id: int = 0):
    if selected_provider == 'AUTO':
        # Use AUTO to let the pyengine choose the first available provider
        return axe.InferenceSession(model_path)

    providers = []
    if selected_provider == axclrt_provider_name:
        provider_options = {"device_id": selected_device_id}
        providers.append((axclrt_provider_name, provider_options))
    if selected_provider == axengine_provider_name:
        providers.append(axengine_provider_name)

    return axe.InferenceSession(model_path, providers=providers)
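# A minimal usage sketch (the model path here is a placeholder):
#   session = load_model("compiled.axmodel", "AUTO")
#   session = load_model("compiled.axmodel", axclrt_provider_name, selected_device_id=0)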


def get_frames(video_name):
    """Yield BGR frames from a video source.

    Args:
        video_name (str): path to an .avi/.mp4 file, or a directory with an
            'img' sub-folder of JPEG frames; if empty, read from a default
            RTSP camera stream.

    Yields:
        np.ndarray: the next frame.
    """
    if not video_name:
        # Fall back to a hard-coded RTSP camera stream; the credentials and
        # address below are placeholders for a local camera.
        rtsp = "rtsp://%s:%s@%s:554/cam/realmonitor?channel=1&subtype=1" % ("admin", "123456", "192.168.1.108")
        cap = cv2.VideoCapture(rtsp)

        # warmup: discard the first few frames
        for _ in range(5):
            cap.read()
        while True:
            ret, frame = cap.read()
            if ret:
                yield cv2.resize(frame, (800, 600))
            else:
                break
    elif video_name.endswith('avi') or \
        video_name.endswith('mp4'):
        cap = cv2.VideoCapture(video_name)
        while True:
            ret, frame = cap.read()
            if ret:
                yield frame
            else:
                break
    else:
        images = sorted(glob.glob(os.path.join(video_name, 'img', '*.jp*')))
        for img in images:
            frame = cv2.imread(img)
            yield frame


class Preprocessor_wo_mask(object):
    def __init__(self):
        self.mean = np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1)).astype(np.float32)
        self.std = np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1)).astype(np.float32)

    def process(self, img_arr: np.ndarray):
        # HWC uint8 image -> NCHW float32 tensor in [0, 1], then normalize
        img_tensor = img_arr.transpose((2, 0, 1)).reshape(
            (1, 3, img_arr.shape[0], img_arr.shape[1])).astype(np.float32) / 255.0
        img_tensor_norm = (img_tensor - self.mean) / self.std  # (1, 3, H, W)
        return img_tensor_norm
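        # For example, a (112, 112, 3) uint8 BGR patch becomes a
        # (1, 3, 112, 112) float32 tensor scaled to [0, 1] and normalized
        # with the ImageNet mean/std above.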


class MFTrackerORT:
    def __init__(self, model_path, fp16=False, provider="AUTO", device_id=0) -> None:
        self.debug = True
        self.model_path = model_path
        self.fp16 = fp16
        self.provider = provider
        self.device_id = device_id
        
        self.init_track_net()
        self.preprocessor = Preprocessor_wo_mask()
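        # Tracker hyper-parameters: crop-area factors and output sizes for the
        # template and search regions, the decay applied to the best online
        # score, and how often (in frames) the online template is refreshed.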
        self.max_score_decay = 1.0
        self.search_factor = 4.5
        self.search_size = 224
        self.template_factor = 2.0
        self.template_size = 112
        self.update_interval = 200
        self.online_size = 1

    def init_track_net(self):
        """Initialize the tracker network from the configured model path and provider."""
        self.ax_session = load_model(self.model_path, self.provider, self.device_id)

    def track_init(self, frame, target_pos=None, target_sz=None):
        """Initialize the tracker with the first frame.

        Args:
            frame (np.ndarray): first BGR frame.
            target_pos (list, optional): target top-left corner [x, y]. Defaults to None.
            target_sz (list, optional): target size [w, h]. Defaults to None.
        """
        self.trace_list = []
        try:
            # [x, y, w, h]
            init_state = [target_pos[0], target_pos[1], target_sz[0], target_sz[1]]
            z_patch_arr, _, z_amask_arr = self.sample_target(frame, init_state, self.template_factor, output_sz=self.template_size)
            template = self.preprocessor.process(z_patch_arr)
            self.template = template
            self.online_template = template

            self.online_state = init_state
            self.online_image = frame
            self.max_pred_score = -1.0
            self.online_max_template = template
            self.online_forget_id = 0

            # save states
            self.state = init_state
            self.frame_id = 0
            print(f"第一帧初始化完毕!")
        except:
            print(f"第一帧初始化异常!")
            exit()

    def track(self, image, info: dict = None):
        """Track the target in a new frame; returns the updated box and confidence."""
        H, W, _ = image.shape
        self.frame_id += 1
        x_patch_arr, resize_factor, x_amask_arr = self.sample_target(image, self.state, self.search_factor,
                                                                output_sz=self.search_size)  # (x1, y1, w, h)
        search = self.preprocessor.process(x_patch_arr)

        # compute ONNX Runtime output prediction
        ort_inputs = {'img_t': self.template, 'img_ot': self.online_template, 'img_search': search}

        ort_outs = self.ax_session.run(None, ort_inputs)

        # print(f">>> lenght trt_outputs: {ort_outs}")
        pred_boxes = ort_outs[0]
        pred_score = ort_outs[1]
        # print(f">>> box and score: {pred_boxes}  {pred_score}")
        # Baseline: take the mean of all predicted boxes as the final result,
        # then scale from normalized [0, 1] coords back to the original image scale
        pred_box = (np.mean(pred_boxes, axis=0) * self.search_size / resize_factor).tolist()  # (cx, cy, w, h)
        # get the final box result
        self.state = self.clip_box(self.map_box_back(pred_box, resize_factor), H, W, margin=10)

        self.max_pred_score = self.max_pred_score * self.max_score_decay
        # update template
        if pred_score > 0.5 and pred_score > self.max_pred_score:
            z_patch_arr, _, z_amask_arr = self.sample_target(image, self.state,
                                                        self.template_factor,
                                                        output_sz=self.template_size)  # (x1, y1, w, h)
            self.online_max_template = self.preprocessor.process(z_patch_arr)
            self.max_pred_score = pred_score

        
        # Periodically refresh the online template with the best candidate
        # collected since the last update.
        if self.frame_id % self.update_interval == 0:
            if self.online_size == 1:
                self.online_template = self.online_max_template
            else:
                self.online_template[self.online_forget_id:self.online_forget_id+1] = self.online_max_template
                self.online_forget_id = (self.online_forget_id + 1) % self.online_size

            self.max_pred_score = -1
            self.online_max_template = self.template

        # for debug
        if self.debug:
            x1, y1, w, h = self.state
            # image_BGR = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            cv2.rectangle(image, (int(x1),int(y1)), (int(x1+w),int(y1+h)), color=(0,0,255), thickness=2)

        return {"target_bbox": self.state, "conf_score": pred_score}

    def map_box_back(self, pred_box: list, resize_factor: float):
        """Map a box from search-crop coordinates back to full-image coordinates.

        The crop was centered on the previous state's center, so adding
        (prev_center - half_side) undoes the cropping. E.g. with state
        [100, 100, 50, 50], search_size 224 and resize_factor 1.0
        (half_side = 112), a prediction centered at (112, 112) maps back
        to the previous center (125, 125).
        """
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box
        half_side = 0.5 * self.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]

    def map_box_back_batch(self, pred_box: np.ndarray, resize_factor: float):
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box.T # (N,4) --> (N,)
        half_side = 0.5 * self.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return np.stack([cx_real - 0.5 * w, cy_real - 0.5 * h, w, h], axis=-1)
    
    def sample_target(self, im, target_bb, search_area_factor, output_sz=None, mask=None):
        """ Extracts a square crop centered at target_bb box, of area search_area_factor^2 times target_bb area

        args:
            im - cv image
            target_bb - target box [x, y, w, h]
            search_area_factor - Ratio of crop size to target size
            output_sz - (float) Size to which the extracted crop is resized (always square). If None, no resizing is done.

        returns:
            cv image - extracted crop
            float - the factor by which the crop has been resized so that its size equals output_sz
            np.ndarray - attention mask marking the padded region
        """
        if not isinstance(target_bb, list):
            x, y, w, h = target_bb.tolist()
        else:
            x, y, w, h = target_bb
        # Crop image
        crop_sz = math.ceil(math.sqrt(w * h) * search_area_factor)

        if crop_sz < 1:
            raise ValueError('Too small bounding box.')

        x1 = int(round(x + 0.5 * w - crop_sz * 0.5))
        x2 = int(x1 + crop_sz)

        y1 = int(round(y + 0.5 * h - crop_sz * 0.5))
        y2 = int(y1 + crop_sz)

        x1_pad = int(max(0, -x1))
        x2_pad = int(max(x2 - im.shape[1] + 1, 0))

        y1_pad = int(max(0, -y1))
        y2_pad = int(max(y2 - im.shape[0] + 1, 0))

        # Crop target
        im_crop = im[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad, :]
        if mask is not None:
            mask_crop = mask[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad]

        # Pad
        im_crop_padded = cv2.copyMakeBorder(im_crop, y1_pad, y2_pad, x1_pad, x2_pad, cv2.BORDER_CONSTANT)
        # deal with attention mask
        H, W, _ = im_crop_padded.shape
        att_mask = np.ones((H,W))
        end_x, end_y = -x2_pad, -y2_pad
        if y2_pad == 0:
            end_y = None
        if x2_pad == 0:
            end_x = None
        att_mask[y1_pad:end_y, x1_pad:end_x] = 0
        if mask is not None:
            mask_crop_padded = cv2.copyMakeBorder(mask_crop, y1_pad, y2_pad, x1_pad, x2_pad, cv2.BORDER_CONSTANT)

        if output_sz is not None:
            resize_factor = output_sz / crop_sz
            im_crop_padded = cv2.resize(im_crop_padded, (output_sz, output_sz))
            att_mask = cv2.resize(att_mask, (output_sz, output_sz)).astype(np.bool_)
            if mask is None:
                return im_crop_padded, resize_factor, att_mask
            mask_crop_padded = cv2.resize(mask_crop_padded, (output_sz, output_sz))
            return im_crop_padded, resize_factor, att_mask, mask_crop_padded

        else:
            if mask is None:
                return im_crop_padded, 1.0, att_mask.astype(np.bool_)
            return im_crop_padded, 1.0, att_mask.astype(np.bool_), mask_crop_padded
        
    def clip_box(self, box: list, H, W, margin=0):
        x1, y1, w, h = box
        x2, y2 = x1 + w, y1 + h
        x1 = min(max(0, x1), W-margin)
        x2 = min(max(margin, x2), W)
        y1 = min(max(0, y1), H-margin)
        y2 = min(max(margin, y2), H)
        w = max(margin, x2-x1)
        h = max(margin, y2-y1)
        return [x1, y1, w, h]
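    # Example: clip_box([-5, 10, 50, 40], H=480, W=640, margin=10) returns
    # [0, 10, 45, 40] -- the box is clamped to the frame and kept at least
    # `margin` pixels wide and tall.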
    
    
def main(model_path, frame_path, repeat, selected_provider, selected_device_id):
    tracker = MFTrackerORT(model_path=model_path, fp16=False,
                           provider=selected_provider, device_id=selected_device_id)
    first_frame = True
    tracker.video_name = frame_path

    frame_id = 0
    total_time = 0
    for frame in get_frames(tracker.video_name):
        # Stop once the requested number of frames has been processed.
        if repeat is not None and frame_id >= repeat:
            print(f"Reached the maximum number of frames ({repeat}). Exiting loop.")
            break
        
        tic = cv2.getTickCount()
        if first_frame:
            # Either select the ROI interactively:
            #   x, y, w, h = cv2.selectROI(video_name, frame, fromCenter=False)
            # or use the hard-coded initial box below for reproducibility.
            x, y, w, h = 1079, 482, 99, 106

            target_pos = [x, y]
            target_sz = [w, h]
            tracker.track_init(frame, target_pos, target_sz)
            first_frame = False
        else:
            state = tracker.track(frame)
            frame_id += 1

            os.makedirs('axmodel_output', exist_ok=True)
            cv2.imwrite(f'axmodel_output/{str(frame_id)}.png', frame)

        toc = cv2.getTickCount() - tic
        fps = cv2.getTickFrequency() / toc
        total_time += fps
        print('Video: {:12s} {:3.1f} fps'.format('tracking', fps))

    print('Video: {:12s} {:3.1f} fps'.format('average tracking', total_time / max(frame_id, 1)))

    

class ExampleParser(argparse.ArgumentParser):
    def error(self, message):
        self.print_usage(sys.stderr)
        print(f"\nError: {message}")
        print("\nExample usage:")
        print("  python3 run_mixformer2_axmodel.py -m <model_file> -f <frame_file>")
        print("  python3 run_mixformer2_axmodel.py -m compiled.axmodel -f car.avi")
        print(
            f"  python3 run_mixformer2_axmodel.py -m compiled.axmodel -f car.avi -p {axengine_provider_name}")
        print(
            f"  python3 run_mixformer2_axmodel.py -m compiled.axmodel -f car.avi -p {axclrt_provider_name}")
        sys.exit(1)


if __name__ == "__main__":
    ap = ExampleParser()
    ap.add_argument('-m', '--model-path', type=str, help='model path', required=True)
    ap.add_argument('-f', '--frame-path', type=str, help='frame path', required=True)
    ap.add_argument('-r', '--repeat', type=int, help='max number of frames to process', default=100)
    ap.add_argument(
        '-p',
        '--provider',
        type=str,
        choices=["AUTO", axclrt_provider_name, axengine_provider_name],
        help=f'"AUTO", "{axclrt_provider_name}", or "{axengine_provider_name}"',
        default='AUTO'
    )
    ap.add_argument(
        '-d',
        '--device-id',
        type=int,
        help='axclrt device index; depends on how many accelerator cards are installed',
        default=0
    )
    args = ap.parse_args()

    model_file = args.model_path
    frame_file = args.frame_path

    # check that the model and frame source exist
    assert os.path.exists(model_file), f"model file path {model_file} does not exist"
    assert os.path.exists(frame_file), f"frame path {frame_file} does not exist"

    repeat = args.repeat

    provider = args.provider
    device_id = args.device_id

    main(model_file, frame_file, repeat, provider, device_id)
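
    # Example invocation (matching the usage printed by ExampleParser):
    #   python3 run_mixformer2_axmodel.py -m compiled.axmodel -f car.avi -r 100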