File size: 7,290 Bytes
55654c5
 
 
 
 
 
 
 
57699b7
55654c5
 
 
 
 
4b236af
55654c5
50fc340
55654c5
 
260eb6d
55654c5
 
 
260eb6d
55654c5
 
 
 
 
 
 
 
d30c3db
55654c5
d33294b
d30c3db
55654c5
 
 
50fc340
 
55654c5
260eb6d
 
55654c5
86cd32d
55654c5
e5b568e
260eb6d
55654c5
 
 
 
 
 
 
 
 
 
260eb6d
55654c5
 
 
d30c3db
 
 
 
55654c5
d30c3db
 
55654c5
 
 
 
 
 
 
 
4b236af
55654c5
 
 
 
 
 
 
 
e5b568e
62917b7
 
e5b568e
55654c5
 
b57baee
55654c5
 
e5b568e
55654c5
 
 
e5b568e
55654c5
 
 
e5b568e
55654c5
d30c3db
e5b568e
55654c5
260eb6d
 
 
 
 
50fc340
 
 
 
c62dd45
 
 
55654c5
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# This file is part of OpenCV Zoo project.
# It is subject to the license terms in the LICENSE file found in the same directory.
#
# Copyright (C) 2021, Shenzhen Institute of Artificial Intelligence and Robotics for Society, all rights reserved.
# Third party copyrights are property of their respective owners.

import os
import sys
import numpy as np
import cv2 as cv

import onnx
from onnx import version_converter
import onnxruntime
from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType, QuantFormat

from transform import Compose, Resize, CenterCrop, Normalize, ColorConvert, HandAlign

class DataReader(CalibrationDataReader):
    """Calibration data source for onnxruntime's ``quantize_static``.

    Every image in ``image_dir`` whose name ends in ``.jpg``/``.png`` (case
    insensitive) is decoded, passed through ``transforms`` and converted to a
    blob keyed by the model's first graph input. ``get_next`` hands the blobs
    out one at a time.

    Args:
        model_path: path to the ONNX model; only its first graph input name is read.
        image_dir: directory holding the calibration images.
        transforms: callable applied to each decoded image; returning ``None``
            drops that image from the calibration set.
        data_dim: ``'chw'`` keeps the layout produced by ``cv.dnn.blobFromImage``;
            ``'hwc'`` permutes the blob axes with ``[0, 2, 3, 1]``.

    Raises:
        ValueError: if ``data_dim`` is neither ``'chw'`` nor ``'hwc'``.
    """
    _SUPPORTED_DIMS = ('chw', 'hwc')
    _SUPPORTED_SUFFIXES = ('jpg', 'png')

    def __init__(self, model_path, image_dir, transforms, data_dim):
        # Fail fast: previously any unknown value silently fell back to 'chw'.
        if data_dim not in self._SUPPORTED_DIMS:
            raise ValueError('unsupported data_dim {!r}; expected one of {}'.format(data_dim, self._SUPPORTED_DIMS))
        model = onnx.load(model_path)
        self.input_name = model.graph.input[0].name
        self.transforms = transforms
        self.data_dim = data_dim
        self.data = self.get_calibration_data(image_dir)
        self.rewind()

    def get_next(self):
        """Return the next ``{input_name: blob}`` feed dict, or ``None`` when exhausted."""
        return next(self.enum_data_dicts, None)

    def rewind(self):
        """Restart iteration from the first calibration blob."""
        self.enum_data_dicts = iter([{self.input_name: x} for x in self.data])

    def get_calibration_data(self, image_dir):
        """Load, transform and blob every supported image found in ``image_dir``.

        Files are visited in sorted order because ``os.listdir`` returns names
        in arbitrary order; sorting keeps the calibration order reproducible.
        """
        blobs = []
        for image_name in sorted(os.listdir(image_dir)):
            image_name_suffix = image_name.split('.')[-1].lower()
            if image_name_suffix not in self._SUPPORTED_SUFFIXES:
                continue
            img = cv.imread(os.path.join(image_dir, image_name))
            # cv.imread signals an unreadable/corrupt file by returning None.
            if img is None:
                continue
            img = self.transforms(img)
            # A transform may reject an image by returning None.
            if img is None:
                continue
            blob = cv.dnn.blobFromImage(img)
            if self.data_dim == 'hwc':
                blob = cv.transposeND(blob, [0, 2, 3, 1])
            blobs.append(blob)
        return blobs

class Quantize:
    """Statically quantize one ONNX model with onnxruntime.

    Args:
        model_path: path to the float ONNX model.
        calibration_image_dir: directory of calibration images for ``DataReader``.
        transforms: preprocessing pipeline; ``None`` means an empty ``Compose()``.
        per_channel: forwarded to ``quantize_static``.
        act_type / wt_type: ``'int8'`` or ``'uint8'``, for activations / weights.
        data_dim: blob layout handed to ``DataReader`` (``'chw'`` or ``'hwc'``).

    Raises:
        ValueError: if ``act_type`` or ``wt_type`` is not a supported type name.
    """
    def __init__(self, model_path, calibration_image_dir, transforms=None, per_channel=False, act_type='int8', wt_type='int8', data_dim='chw'):
        self.type_dict = {"uint8" : QuantType.QUInt8, "int8" : QuantType.QInt8}
        for arg_name, value in (('act_type', act_type), ('wt_type', wt_type)):
            if value not in self.type_dict:
                raise ValueError('unsupported {} {!r}; expected one of {}'.format(arg_name, value, sorted(self.type_dict)))

        self.model_path = model_path
        self.calibration_image_dir = calibration_image_dir
        # Avoid a default Compose() instance shared across every Quantize object.
        self.transforms = Compose() if transforms is None else transforms
        self.per_channel = per_channel
        self.act_type = act_type
        self.wt_type = wt_type
        self.data_dim = data_dim

        # The data reader is built lazily (see ``dr``). Building it here would
        # read every model's calibration set whenever a Quantize is constructed,
        # even for models that are never run.
        self._dr = None

    @property
    def dr(self):
        """Calibration ``DataReader``, created on first access."""
        if self._dr is None:
            self._dr = DataReader(self.model_path, self.calibration_image_dir, self.transforms, self.data_dim)
        return self._dr

    @dr.setter
    def dr(self, reader):
        self._dr = reader

    @staticmethod
    def _default_opset_version(model):
        """Return the opset version of the default (ai.onnx) domain, else the first entry."""
        for opset in model.opset_import:
            if opset.domain in ('', 'ai.onnx'):
                return opset.version
        return model.opset_import[0].version

    def check_opset(self, convert=True):
        """Make sure ``self.model_path`` points at an opset-13 model.

        When the model uses another opset and ``convert`` is true, a converted
        copy is saved as ``<model>-opset.onnx`` and ``self.model_path`` is
        switched to it. With ``convert=False`` the mismatch is only reported.
        """
        model = onnx.load(self.model_path)
        version = self._default_opset_version(model)
        if version == 13:
            return
        if not convert:
            print('\tmodel opset version: {}. Skipping conversion to opset 13'.format(version))
            return
        print('\tmodel opset version: {}. Converting to opset 13'.format(version))
        # convert opset version to 13
        model_opset13 = version_converter.convert_version(model, 13)
        # save converted model
        output_name = '{}-opset.onnx'.format(os.path.splitext(self.model_path)[0])
        onnx.save_model(model_opset13, output_name)
        # update model_path for quantization
        self.model_path = output_name

    def run(self):
        """Quantize the model and save ``<model>-act_<a>-wt_<w>-quantized.onnx``."""
        print('Quantizing {}: act_type {}, wt_type {}'.format(self.model_path, self.act_type, self.wt_type))
        # Build the reader from the original model before check_opset may swap
        # model_path; the converted copy is expected to keep the same input name.
        data_reader = self.dr
        self.check_opset()
        base_name = os.path.splitext(self.model_path)[0]
        output_name = '{}-act_{}-wt_{}-quantized.onnx'.format(base_name, self.act_type, self.wt_type)
        try:
            quantize_static(self.model_path, output_name, data_reader,
                            quant_format=QuantFormat.QOperator, # start from onnxruntime==1.11.0, quant_format is set to QuantFormat.QDQ by default, which performs fake quantization
                            per_channel=self.per_channel,
                            weight_type=self.type_dict[self.wt_type],
                            activation_type=self.type_dict[self.act_type])
        finally:
            # Calibration consumes the reader; drop it so a later run() starts fresh.
            self._dr = None
            # Some onnxruntime versions write these intermediates to disk and
            # others do not, so only remove the files that exist.
            for leftover in ('augmented_model.onnx', '{}-opt.onnx'.format(base_name)):
                if os.path.exists(leftover):
                    os.remove(leftover)
        print('\tQuantized model saved to {}'.format(output_name))

# Quantization jobs, keyed by the name accepted on the command line.
# Every Quantize below is constructed when this module is imported.
models = {
    'yunet': Quantize(model_path='../../models/face_detection_yunet/face_detection_yunet_2022mar.onnx',
                      calibration_image_dir='../../benchmark/data/face_detection',
                      transforms=Compose([Resize(size=(160, 120))])),
    'sface': Quantize(model_path='../../models/face_recognition_sface/face_recognition_sface_2021dec.onnx',
                      calibration_image_dir='../../benchmark/data/face_recognition',
                      transforms=Compose([Resize(size=(112, 112))])),
    'pphumanseg': Quantize(model_path='../../models/human_segmentation_pphumanseg/human_segmentation_pphumanseg_2023mar.onnx',
                           calibration_image_dir='../../benchmark/data/human_segmentation',
                           transforms=Compose([Resize(size=(192, 192))])),
    'ppresnet50': Quantize(model_path='../../models/image_classification_ppresnet/image_classification_ppresnet50_2022jan.onnx',
                           calibration_image_dir='../../benchmark/data/image_classification',
                           transforms=Compose([Resize(size=(224, 224))])),
    # TBD: DaSiamRPN
    'youtureid': Quantize(model_path='../../models/person_reid_youtureid/person_reid_youtu_2021nov.onnx',
                          calibration_image_dir='../../benchmark/data/person_reid',
                          transforms=Compose([Resize(size=(128, 256))])),
    # TBD: DB-EN & DB-CN
    # The English CRNN is calibrated on normalized grayscale crops; the Chinese
    # one only on resized crops.
    'crnn_en': Quantize(model_path='../../models/text_recognition_crnn/text_recognition_CRNN_EN_2021sep.onnx',
                        calibration_image_dir='../../benchmark/data/text',
                        transforms=Compose([Resize(size=(100, 32)),
                                            Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5]),
                                            ColorConvert(ctype=cv.COLOR_BGR2GRAY)])),
    'crnn_cn': Quantize(model_path='../../models/text_recognition_crnn/text_recognition_CRNN_CN_2021nov.onnx',
                        calibration_image_dir='../../benchmark/data/text',
                        transforms=Compose([Resize(size=(100, 32))])),
    # NOTE: 'path/to/dataset' appears to be a placeholder; point it at a local
    # calibration set before quantizing the two MediaPipe models below.
    'mp_palmdet': Quantize(model_path='../../models/palm_detection_mediapipe/palm_detection_mediapipe_2023feb.onnx',
                           calibration_image_dir='path/to/dataset',
                           transforms=Compose([Resize(size=(192, 192)),
                                               Normalize(std=[255, 255, 255]),
                                               ColorConvert(ctype=cv.COLOR_BGR2RGB)]),
                           data_dim='hwc'),
    'mp_handpose': Quantize(model_path='../../models/handpose_estimation_mediapipe/handpose_estimation_mediapipe_2023feb.onnx',
                            calibration_image_dir='path/to/dataset',
                            transforms=Compose([HandAlign("mp_handpose"),
                                                Resize(size=(224, 224)),
                                                Normalize(std=[255, 255, 255]),
                                                ColorConvert(ctype=cv.COLOR_BGR2RGB)]),
                            data_dim='hwc'),
    'lpd_yunet': Quantize(model_path='../../models/license_plate_detection_yunet/license_plate_detection_lpd_yunet_2023mar.onnx',
                          calibration_image_dir='../../benchmark/data/license_plate_detection',
                          transforms=Compose([Resize(size=(320, 240))])),
}

if __name__ == '__main__':
    # Usage: python <script> [model_name ...]; with no names, every model in
    # ``models`` is quantized.
    selected_models = sys.argv[1:] or list(models.keys())

    # Reject unknown names before any model is processed, so a typo does not
    # end in a KeyError traceback after earlier models were already quantized.
    unknown_models = [name for name in selected_models if name not in models]
    if unknown_models:
        sys.exit('Unknown model(s): {}. Available models: {}'.format(
            ', '.join(unknown_models), ', '.join(models)))

    print('Models to be quantized: {}'.format(str(selected_models)))

    for selected_model_name in selected_models:
        q = models[selected_model_name]
        q.run()