shorten int8-quantized naming (#149)
Browse files
tools/quantize/quantize-ort.py
CHANGED
|
@@ -59,29 +59,30 @@ class Quantize:
|
|
| 59 |
# data reader
|
| 60 |
self.dr = DataReader(self.model_path, self.calibration_image_dir, self.transforms, data_dim)
|
| 61 |
|
| 62 |
-
def check_opset(self
|
| 63 |
model = onnx.load(self.model_path)
|
| 64 |
if model.opset_import[0].version != 13:
|
| 65 |
print('\tmodel opset version: {}. Converting to opset 13'.format(model.opset_import[0].version))
|
| 66 |
# convert opset version to 13
|
| 67 |
model_opset13 = version_converter.convert_version(model, 13)
|
| 68 |
# save converted model
|
| 69 |
-
output_name = '{}-
|
| 70 |
onnx.save_model(model_opset13, output_name)
|
| 71 |
# update model_path for quantization
|
| 72 |
-
|
|
|
|
| 73 |
|
| 74 |
def run(self):
|
| 75 |
print('Quantizing {}: act_type {}, wt_type {}'.format(self.model_path, self.act_type, self.wt_type))
|
| 76 |
-
self.check_opset()
|
| 77 |
-
output_name = '{}
|
| 78 |
-
quantize_static(
|
| 79 |
quant_format=QuantFormat.QOperator, # start from onnxruntime==1.11.0, quant_format is set to QuantFormat.QDQ by default, which performs fake quantization
|
| 80 |
per_channel=self.per_channel,
|
| 81 |
weight_type=self.type_dict[self.wt_type],
|
| 82 |
activation_type=self.type_dict[self.act_type])
|
| 83 |
-
|
| 84 |
-
|
| 85 |
print('\tQuantized model saved to {}'.format(output_name))
|
| 86 |
|
| 87 |
models=dict(
|
|
@@ -132,4 +133,3 @@ if __name__ == '__main__':
|
|
| 132 |
for selected_model_name in selected_models:
|
| 133 |
q = models[selected_model_name]
|
| 134 |
q.run()
|
| 135 |
-
|
|
|
|
| 59 |
# data reader
|
| 60 |
self.dr = DataReader(self.model_path, self.calibration_image_dir, self.transforms, data_dim)
|
| 61 |
|
| 62 |
+
def check_opset(self, target_opset=13):
    """Ensure the model at ``self.model_path`` uses the required ONNX opset.

    Loads the model and, when its opset differs from ``target_opset``,
    converts it and saves the converted copy next to the original with an
    ``-opset{N}.onnx`` suffix.

    Args:
        target_opset: opset version required by quantization. Defaults to
            13, matching the previously hard-coded requirement.

    Returns:
        Path of the model to quantize: the converted copy when a
        conversion was performed, otherwise the original ``self.model_path``.
    """
    model = onnx.load(self.model_path)
    # NOTE(review): assumes the default domain is the first opset_import
    # entry — true for models exported by common tools; confirm if
    # custom-domain models are expected.
    if model.opset_import[0].version != target_opset:
        print('\tmodel opset version: {}. Converting to opset {}'.format(
            model.opset_import[0].version, target_opset))
        # convert opset version to the required one
        converted_model = version_converter.convert_version(model, target_opset)
        # save the converted model beside the original; [:-5] strips the
        # '.onnx' extension (assumes model_path ends with '.onnx' — TODO confirm)
        output_name = '{}-opset{}.onnx'.format(self.model_path[:-5], target_opset)
        onnx.save_model(converted_model, output_name)
        # hand the converted copy to the caller for quantization
        return output_name
    return self.model_path
|
| 74 |
|
| 75 |
def run(self):
    """Statically quantize the model and write the result to disk.

    Converts the model to the required opset first when necessary, feeds
    the calibration data reader prepared in ``__init__`` to
    ``quantize_static``, and removes any temporary converted model when done.
    """
    print('Quantizing {}: act_type {}, wt_type {}'.format(
        self.model_path, self.act_type, self.wt_type))
    # model actually fed into quantize_static: either a temporary
    # opset-converted copy, or the original path when no conversion was needed
    source_model = self.check_opset()
    quantized_name = '{}_{}.onnx'.format(self.model_path[:-5], self.wt_type)
    quantize_static(
        source_model,
        quantized_name,
        self.dr,
        quant_format=QuantFormat.QOperator,  # start from onnxruntime==1.11.0, quant_format is set to QuantFormat.QDQ by default, which performs fake quantization
        per_channel=self.per_channel,
        weight_type=self.type_dict[self.wt_type],
        activation_type=self.type_dict[self.act_type])
    # clean up the temporary opset-converted model, if one was created
    if source_model != self.model_path:
        os.remove(source_model)
    print('\tQuantized model saved to {}'.format(quantized_name))
|
| 87 |
|
| 88 |
models=dict(
|
|
|
|
| 133 |
# Run quantization for each model selected on the command line.
# NOTE(review): `models` and `selected_models` are defined earlier in the
# file (outside this view); each entry is presumed to be a configured
# Quantize instance — confirm against the `models=dict(...)` block.
for selected_model_name in selected_models:
    q = models[selected_model_name]
    q.run()
|
|
|