Wanli
commited on
Commit
·
17cd08b
1
Parent(s):
850fa1e
fix some error and bugs (#112)
Browse files- demo.py +2 -2
- mobilenet_v1.py +2 -2
- mobilenet_v2.py +2 -2
demo.py
CHANGED
@@ -39,8 +39,8 @@ if __name__ == '__main__':
|
|
39 |
models = {
|
40 |
'v1': MobileNetV1(modelPath='./image_classification_mobilenetv1_2022apr.onnx', labelPath=args.label, backendId=args.backend, targetId=args.target),
|
41 |
'v2': MobileNetV2(modelPath='./image_classification_mobilenetv2_2022apr.onnx', labelPath=args.label, backendId=args.backend, targetId=args.target),
|
42 |
-
'v1-q': MobileNetV1(modelPath='./image_classification_mobilenetv1_2022apr-
|
43 |
-
'v2-q': MobileNetV2(modelPath='./image_classification_mobilenetv2_2022apr-
|
44 |
|
45 |
}
|
46 |
model = models[args.model]
|
|
|
39 |
models = {
|
40 |
'v1': MobileNetV1(modelPath='./image_classification_mobilenetv1_2022apr.onnx', labelPath=args.label, backendId=args.backend, targetId=args.target),
|
41 |
'v2': MobileNetV2(modelPath='./image_classification_mobilenetv2_2022apr.onnx', labelPath=args.label, backendId=args.backend, targetId=args.target),
|
42 |
+
'v1-q': MobileNetV1(modelPath='./image_classification_mobilenetv1_2022apr-int8-quantized.onnx', labelPath=args.label, backendId=args.backend, targetId=args.target),
|
43 |
+
'v2-q': MobileNetV2(modelPath='./image_classification_mobilenetv2_2022apr-int8-quantized.onnx', labelPath=args.label, backendId=args.backend, targetId=args.target)
|
44 |
|
45 |
}
|
46 |
model = models[args.model]
|
mobilenet_v1.py
CHANGED
@@ -21,7 +21,7 @@ class MobileNetV1:
|
|
21 |
self.std=[0.229, 0.224, 0.225]
|
22 |
|
23 |
# load labels
|
24 |
-
self.
|
25 |
|
26 |
def _load_labels(self):
|
27 |
labels = []
|
@@ -68,7 +68,7 @@ class MobileNetV1:
|
|
68 |
for o in output_blob:
|
69 |
class_id_list = o.argsort()[::-1][:self.top_k]
|
70 |
batched_class_id_list.append(class_id_list)
|
71 |
-
if len(self.
|
72 |
batched_predicted_labels = []
|
73 |
for class_id_list in batched_class_id_list:
|
74 |
predicted_labels = []
|
|
|
21 |
self.std=[0.229, 0.224, 0.225]
|
22 |
|
23 |
# load labels
|
24 |
+
self._labels = self._load_labels()
|
25 |
|
26 |
def _load_labels(self):
|
27 |
labels = []
|
|
|
68 |
for o in output_blob:
|
69 |
class_id_list = o.argsort()[::-1][:self.top_k]
|
70 |
batched_class_id_list.append(class_id_list)
|
71 |
+
if len(self._labels) > 0:
|
72 |
batched_predicted_labels = []
|
73 |
for class_id_list in batched_class_id_list:
|
74 |
predicted_labels = []
|
mobilenet_v2.py
CHANGED
@@ -21,7 +21,7 @@ class MobileNetV2:
|
|
21 |
self.std=[0.229, 0.224, 0.225]
|
22 |
|
23 |
# load labels
|
24 |
-
self.
|
25 |
|
26 |
def _load_labels(self):
|
27 |
labels = []
|
@@ -68,7 +68,7 @@ class MobileNetV2:
|
|
68 |
for o in output_blob:
|
69 |
class_id_list = o.argsort()[::-1][:self.top_k]
|
70 |
batched_class_id_list.append(class_id_list)
|
71 |
-
if len(self.
|
72 |
batched_predicted_labels = []
|
73 |
for class_id_list in batched_class_id_list:
|
74 |
predicted_labels = []
|
|
|
21 |
self.std=[0.229, 0.224, 0.225]
|
22 |
|
23 |
# load labels
|
24 |
+
self._labels = self._load_labels()
|
25 |
|
26 |
def _load_labels(self):
|
27 |
labels = []
|
|
|
68 |
for o in output_blob:
|
69 |
class_id_list = o.argsort()[::-1][:self.top_k]
|
70 |
batched_class_id_list.append(class_id_list)
|
71 |
+
if len(self._labels) > 0:
|
72 |
batched_predicted_labels = []
|
73 |
for class_id_list in batched_class_id_list:
|
74 |
predicted_labels = []
|