ytfeng committed
Commit 284d26a · 1 Parent(s): a07f7bd

shorten int8-quantized naming (#149)

Files changed (1)
  1. tools/quantize/quantize-ort.py +9 -9
tools/quantize/quantize-ort.py CHANGED
@@ -59,29 +59,30 @@ class Quantize:
         # data reader
         self.dr = DataReader(self.model_path, self.calibration_image_dir, self.transforms, data_dim)
 
-    def check_opset(self, convert=True):
+    def check_opset(self):
         model = onnx.load(self.model_path)
         if model.opset_import[0].version != 13:
             print('\tmodel opset version: {}. Converting to opset 13'.format(model.opset_import[0].version))
             # convert opset version to 13
             model_opset13 = version_converter.convert_version(model, 13)
             # save converted model
-            output_name = '{}-opset.onnx'.format(self.model_path[:-5])
+            output_name = '{}-opset13.onnx'.format(self.model_path[:-5])
             onnx.save_model(model_opset13, output_name)
             # update model_path for quantization
-            self.model_path = output_name
+            return output_name
+        return self.model_path
 
     def run(self):
         print('Quantizing {}: act_type {}, wt_type {}'.format(self.model_path, self.act_type, self.wt_type))
-        self.check_opset()
-        output_name = '{}-act_{}-wt_{}-quantized.onnx'.format(self.model_path[:-5], self.act_type, self.wt_type)
-        quantize_static(self.model_path, output_name, self.dr,
+        new_model_path = self.check_opset()
+        output_name = '{}_{}.onnx'.format(self.model_path[:-5], self.wt_type)
+        quantize_static(new_model_path, output_name, self.dr,
                         quant_format=QuantFormat.QOperator, # start from onnxruntime==1.11.0, quant_format is set to QuantFormat.QDQ by default, which performs fake quantization
                         per_channel=self.per_channel,
                         weight_type=self.type_dict[self.wt_type],
                         activation_type=self.type_dict[self.act_type])
-        os.remove('augmented_model.onnx')
-        os.remove('{}-opt.onnx'.format(self.model_path[:-5]))
+        if new_model_path != self.model_path:
+            os.remove(new_model_path)
         print('\tQuantized model saved to {}'.format(output_name))
 
 models=dict(
@@ -132,4 +133,3 @@ if __name__ == '__main__':
     for selected_model_name in selected_models:
         q = models[selected_model_name]
         q.run()
-
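
For reference, a minimal sketch of what the shortened naming means in practice. The model file name and quantization types below are hypothetical examples chosen for illustration; only the two format strings come from this commit.

model_path = 'mobilenet.onnx'          # hypothetical input model, not part of this commit
act_type, wt_type = 'int8', 'int8'     # hypothetical quantization types

# naming before this commit:
old_name = '{}-act_{}-wt_{}-quantized.onnx'.format(model_path[:-5], act_type, wt_type)
# -> 'mobilenet-act_int8-wt_int8-quantized.onnx'

# shortened naming after this commit:
new_name = '{}_{}.onnx'.format(model_path[:-5], wt_type)
# -> 'mobilenet_int8.onnx'

print(old_name)
print(new_name)

With the refactored check_opset(), run() quantizes whatever path check_opset() returns and only deletes that file when it is a temporary opset-converted copy, so the original model on disk is left untouched.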