ONNX
Wanli commited on
Commit
17cd08b
·
1 Parent(s): 850fa1e

fix some error and bugs (#112)

Browse files
Files changed (3) hide show
  1. demo.py +2 -2
  2. mobilenet_v1.py +2 -2
  3. mobilenet_v2.py +2 -2
demo.py CHANGED
@@ -39,8 +39,8 @@ if __name__ == '__main__':
39
  models = {
40
  'v1': MobileNetV1(modelPath='./image_classification_mobilenetv1_2022apr.onnx', labelPath=args.label, backendId=args.backend, targetId=args.target),
41
  'v2': MobileNetV2(modelPath='./image_classification_mobilenetv2_2022apr.onnx', labelPath=args.label, backendId=args.backend, targetId=args.target),
42
- 'v1-q': MobileNetV1(modelPath='./image_classification_mobilenetv1_2022apr-act_int8-wt_int8-quantized.onnx', labelPath=args.label, backendId=args.backend, targetId=args.target),
43
- 'v2-q': MobileNetV2(modelPath='./image_classification_mobilenetv2_2022apr-act_int8-wt_int8-quantized.onnx', labelPath=args.label, backendId=args.backend, targetId=args.target)
44
 
45
  }
46
  model = models[args.model]
 
39
  models = {
40
  'v1': MobileNetV1(modelPath='./image_classification_mobilenetv1_2022apr.onnx', labelPath=args.label, backendId=args.backend, targetId=args.target),
41
  'v2': MobileNetV2(modelPath='./image_classification_mobilenetv2_2022apr.onnx', labelPath=args.label, backendId=args.backend, targetId=args.target),
42
+ 'v1-q': MobileNetV1(modelPath='./image_classification_mobilenetv1_2022apr-int8-quantized.onnx', labelPath=args.label, backendId=args.backend, targetId=args.target),
43
+ 'v2-q': MobileNetV2(modelPath='./image_classification_mobilenetv2_2022apr-int8-quantized.onnx', labelPath=args.label, backendId=args.backend, targetId=args.target)
44
 
45
  }
46
  model = models[args.model]
mobilenet_v1.py CHANGED
@@ -21,7 +21,7 @@ class MobileNetV1:
21
  self.std=[0.229, 0.224, 0.225]
22
 
23
  # load labels
24
- self.labels = self._load_labels()
25
 
26
  def _load_labels(self):
27
  labels = []
@@ -68,7 +68,7 @@ class MobileNetV1:
68
  for o in output_blob:
69
  class_id_list = o.argsort()[::-1][:self.top_k]
70
  batched_class_id_list.append(class_id_list)
71
- if len(self.labels) > 0:
72
  batched_predicted_labels = []
73
  for class_id_list in batched_class_id_list:
74
  predicted_labels = []
 
21
  self.std=[0.229, 0.224, 0.225]
22
 
23
  # load labels
24
+ self._labels = self._load_labels()
25
 
26
  def _load_labels(self):
27
  labels = []
 
68
  for o in output_blob:
69
  class_id_list = o.argsort()[::-1][:self.top_k]
70
  batched_class_id_list.append(class_id_list)
71
+ if len(self._labels) > 0:
72
  batched_predicted_labels = []
73
  for class_id_list in batched_class_id_list:
74
  predicted_labels = []
mobilenet_v2.py CHANGED
@@ -21,7 +21,7 @@ class MobileNetV2:
21
  self.std=[0.229, 0.224, 0.225]
22
 
23
  # load labels
24
- self.labels = self._load_labels()
25
 
26
  def _load_labels(self):
27
  labels = []
@@ -68,7 +68,7 @@ class MobileNetV2:
68
  for o in output_blob:
69
  class_id_list = o.argsort()[::-1][:self.top_k]
70
  batched_class_id_list.append(class_id_list)
71
- if len(self.labels) > 0:
72
  batched_predicted_labels = []
73
  for class_id_list in batched_class_id_list:
74
  predicted_labels = []
 
21
  self.std=[0.229, 0.224, 0.225]
22
 
23
  # load labels
24
+ self._labels = self._load_labels()
25
 
26
  def _load_labels(self):
27
  labels = []
 
68
  for o in output_blob:
69
  class_id_list = o.argsort()[::-1][:self.top_k]
70
  batched_class_id_list.append(class_id_list)
71
+ if len(self._labels) > 0:
72
  batched_predicted_labels = []
73
  for class_id_list in batched_class_id_list:
74
  predicted_labels = []