Yuantao Feng commited on
Commit
55654c5
·
1 Parent(s): 2074f99

Add tools for quantization and quantized models (#36)

Browse files

* add scripts for quantization

* update path to pp-resnet50

* add quantized models

* rename dict to models

* add requirements and readme

* fix typos

tools/quantize/README.md ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Quantization with ONNXRUNTIME
2
+
3
+ ONNXRUNTIME is used for quantization in the Zoo.
4
+
5
+ Install dependencies before trying quantization:
6
+ ```shell
7
+ pip install -r requirements.txt
8
+ ```
9
+
10
+ ## Usage
11
+
12
+ Quantize all models in the Zoo:
13
+ ```shell
14
+ python quantize.py
15
+ ```
16
+
17
+ Quantize one of the models in the Zoo:
18
+ ```shell
19
+ # python quantize.py <key_in_models>
20
+ python quantize.py yunet
21
+ ```
22
+
23
+ Customizing quantization configs:
24
+ ```python
25
+ # add model into `models` dict in quantize.py
26
+ models = dict(
27
+ # ...
28
+ model1=Quantize(model_path='/path/to/model1.onnx'
29
+ calibration_image_dir='/path/to/images',
30
+ transforms=Compose([''' transforms ''']), # transforms can be found in transforms.py
31
+ per_channel=False, # set False to quantize in per-tensor style
32
+ act_type='int8', # available types: 'int8', 'uint8'
33
+ wt_type='int8' # available types: 'int8', 'uint8'
34
+ )
35
+ )
36
+ # quantize the added models
37
+ python quantize.py model1
38
+ ```
tools/quantize/quantize.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file is part of OpenCV Zoo project.
2
+ # It is subject to the license terms in the LICENSE file found in the same directory.
3
+ #
4
+ # Copyright (C) 2021, Shenzhen Institute of Artificial Intelligence and Robotics for Society, all rights reserved.
5
+ # Third party copyrights are property of their respective owners.
6
+
7
+ import os
8
+ import sys
9
+ import numpy as ny
10
+ import cv2 as cv
11
+
12
+ import onnx
13
+ from onnx import version_converter
14
+ import onnxruntime
15
+ from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType
16
+
17
+ from transform import Compose, Resize, ColorConvert
18
+
19
+ class DataReader(CalibrationDataReader):
20
+ def __init__(self, model_path, image_dir, transforms):
21
+ model = onnx.load(model_path)
22
+ self.input_name = model.graph.input[0].name
23
+ self.transforms = transforms
24
+ self.data = self.get_calibration_data(image_dir)
25
+ self.enum_data_dicts = iter([{self.input_name: x} for x in self.data])
26
+
27
+ def get_next(self):
28
+ return next(self.enum_data_dicts, None)
29
+
30
+ def get_calibration_data(self, image_dir):
31
+ blobs = []
32
+ for image_name in os.listdir(image_dir):
33
+ if not image_name.endswith('jpg'):
34
+ continue
35
+ img = cv.imread(os.path.join(image_dir, image_name))
36
+ img = self.transforms(img)
37
+ blob = cv.dnn.blobFromImage(img)
38
+ blobs.append(blob)
39
+ return blobs
40
+
41
+ class Quantize:
42
+ def __init__(self, model_path, calibration_image_dir, transforms=Compose(), per_channel=False, act_type='int8', wt_type='int8'):
43
+ self.type_dict = {"uint8" : QuantType.QUInt8, "int8" : QuantType.QInt8}
44
+
45
+ self.model_path = model_path
46
+ self.calibration_image_dir = calibration_image_dir
47
+ self.transforms = transforms
48
+ self.per_channel = per_channel
49
+ self.act_type = act_type
50
+ self.wt_type = wt_type
51
+
52
+ # data reader
53
+ self.dr = DataReader(self.model_path, self.calibration_image_dir, self.transforms)
54
+
55
+ def check_opset(self, convert=True):
56
+ model = onnx.load(self.model_path)
57
+ if model.opset_import[0].version != 11:
58
+ print('\tmodel opset version: {}. Converting to opset 11'.format(model.opset_import[0].version))
59
+ # convert opset version to 11
60
+ model_opset11 = version_converter.convert_version(model, 11)
61
+ # save converted model
62
+ output_name = '{}-opset11.onnx'.format(self.model_path[:-5])
63
+ onnx.save_model(model_opset11, output_name)
64
+ # update model_path for quantization
65
+ self.model_path = output_name
66
+
67
+ def run(self):
68
+ print('Quantizing {}: act_type {}, wt_type {}'.format(self.model_path, self.act_type, self.wt_type))
69
+ self.check_opset()
70
+ output_name = '{}-act_{}-wt_{}-quantized.onnx'.format(self.model_path[:-5], self.act_type, self.wt_type)
71
+ quantize_static(self.model_path, output_name, self.dr,
72
+ per_channel=self.per_channel,
73
+ weight_type=self.type_dict[self.wt_type],
74
+ activation_type=self.type_dict[self.act_type])
75
+ os.remove('augmented_model.onnx')
76
+ os.remove('{}-opt.onnx'.format(self.model_path[:-5]))
77
+ print('\tQuantized model saved to {}'.format(output_name))
78
+
79
+
80
+ models=dict(
81
+ yunet=Quantize(model_path='../../models/face_detection_yunet/face_detection_yunet_2021dec.onnx',
82
+ calibration_image_dir='../../benchmark/data/face_detection'),
83
+ sface=Quantize(model_path='../../models/face_recognition_sface/face_recognition_sface_2021dec.onnx',
84
+ calibration_image_dir='../../benchmark/data/face_recognition',
85
+ transforms=Compose([Resize(size=(112, 112))])),
86
+ pphumenseg=Quantize(model_path='../../models/human_segmentation_pphumanseg/human_segmentation_pphumanseg_2021oct.onnx',
87
+ calibration_image_dir='../../benchmark/data/human_segmentation',
88
+ transforms=Compose([Resize(size=(192, 192))])),
89
+ ppresnet50=Quantize(model_path='../../models/image_classification_ppresnet/image_classification_ppresnet50_2022jan.onnx',
90
+ calibration_image_dir='../../benchmark/data/image_classification',
91
+ transforms=Compose([Resize(size=(224, 224))])),
92
+ # TBD: DaSiamRPN
93
+ youtureid=Quantize(model_path='../../models/person_reid_youtureid/person_reid_youtu_2021nov.onnx',
94
+ calibration_image_dir='../../benchmark/data/person_reid',
95
+ transforms=Compose([Resize(size=(128, 256))])),
96
+ # TBD: DB-EN & DB-CN
97
+ crnn_en=Quantize(model_path='../../models/text_recognition_crnn/text_recognition_CRNN_EN_2021sep.onnx',
98
+ calibration_image_dir='../../benchmark/data/text',
99
+ transforms=Compose([Resize(size=(100, 32)), ColorConvert(ctype=cv.COLOR_BGR2GRAY)])),
100
+ crnn_cn=Quantize(model_path='../../models/text_recognition_crnn/text_recognition_CRNN_CN_2021nov.onnx',
101
+ calibration_image_dir='../../benchmark/data/text',
102
+ transforms=Compose([Resize(size=(100, 32))]))
103
+ )
104
+
105
+ if __name__ == '__main__':
106
+ selected_models = []
107
+ for i in range(1, len(sys.argv)):
108
+ selected_models.append(sys.argv[i])
109
+ if not selected_models:
110
+ selected_models = list(models.keys())
111
+ print('Models to be quantized: {}'.format(str(selected_models)))
112
+
113
+ for selected_model_name in selected_models:
114
+ q = models[selected_model_name]
115
+ q.run()
116
+
tools/quantize/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ opencv-python>=4.5.4.58
2
+ onnx
3
+ onnxruntime
tools/quantize/transform.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file is part of OpenCV Zoo project.
2
+ # It is subject to the license terms in the LICENSE file found in the same directory.
3
+ #
4
+ # Copyright (C) 2021, Shenzhen Institute of Artificial Intelligence and Robotics for Society, all rights reserved.
5
+ # Third party copyrights are property of their respective owners.
6
+
7
+ import numpy as numpy
8
+ import cv2 as cv
9
+
10
+ class Compose:
11
+ def __init__(self, transforms=[]):
12
+ self.transforms = transforms
13
+
14
+ def __call__(self, img):
15
+ for t in self.transforms:
16
+ img = t(img)
17
+ return img
18
+
19
+ class Resize:
20
+ def __init__(self, size, interpolation=cv.INTER_LINEAR):
21
+ self.size = size
22
+ self.interpolation = interpolation
23
+
24
+ def __call__(self, img):
25
+ return cv.resize(img, self.size)
26
+
27
+ class ColorConvert:
28
+ def __init__(self, ctype):
29
+ self.ctype = ctype
30
+
31
+ def __call__(self, img):
32
+ return cv.cvtColor(img, self.ctype)