Wanli commited on
Commit
160af44
·
1 Parent(s): ea63088

Add 'load_label' parameter for image classification models (#185)

Browse files

* add 'load_label' parameter for image classification models

* move load_label flag to initializer

models/image_classification_mobilenet/demo.py CHANGED
@@ -31,13 +31,16 @@ parser.add_argument('--backend_target', '-bt', type=int, default=0,
31
  {:d}: TIM-VX + NPU,
32
  {:d}: CANN + NPU
33
  '''.format(*[x for x in range(len(backend_target_pairs))]))
 
 
34
  args = parser.parse_args()
35
 
36
  if __name__ == '__main__':
37
  backend_id = backend_target_pairs[args.backend_target][0]
38
  target_id = backend_target_pairs[args.backend_target][1]
 
39
  # Instantiate MobileNet
40
- model = MobileNet(modelPath=args.model, backendId=backend_id, targetId=target_id)
41
 
42
  # Read image and get a 224x224 crop from a 256x256 resized
43
  image = cv.imread(args.input)
 
31
  {:d}: TIM-VX + NPU,
32
  {:d}: CANN + NPU
33
  '''.format(*[x for x in range(len(backend_target_pairs))]))
34
+ parser.add_argument('--top_k', type=int, default=1,
35
+ help='Usage: Get top k predictions.')
36
  args = parser.parse_args()
37
 
38
  if __name__ == '__main__':
39
  backend_id = backend_target_pairs[args.backend_target][0]
40
  target_id = backend_target_pairs[args.backend_target][1]
41
+ top_k = args.top_k
42
  # Instantiate MobileNet
43
+ model = MobileNet(modelPath=args.model, topK=top_k, backendId=backend_id, targetId=target_id)
44
 
45
  # Read image and get a 224x224 crop from a 256x256 resized
46
  image = cv.imread(args.input)
models/image_classification_mobilenet/mobilenet.py CHANGED
@@ -6,10 +6,11 @@ class MobileNet:
6
  Works with MobileNet V1 & V2.
7
  '''
8
 
9
- def __init__(self, modelPath, topK=1, backendId=0, targetId=0):
10
  self.model_path = modelPath
11
  assert topK >= 1
12
  self.top_k = topK
 
13
  self.backend_id = backendId
14
  self.target_id = targetId
15
 
@@ -64,7 +65,7 @@ class MobileNet:
64
  for o in output_blob:
65
  class_id_list = o.argsort()[::-1][:self.top_k]
66
  batched_class_id_list.append(class_id_list)
67
- if len(self._labels) > 0:
68
  batched_predicted_labels = []
69
  for class_id_list in batched_class_id_list:
70
  predicted_labels = []
 
6
  Works with MobileNet V1 & V2.
7
  '''
8
 
9
+ def __init__(self, modelPath, topK=1, loadLabel=True, backendId=0, targetId=0):
10
  self.model_path = modelPath
11
  assert topK >= 1
12
  self.top_k = topK
13
+ self.load_label = loadLabel
14
  self.backend_id = backendId
15
  self.target_id = targetId
16
 
 
65
  for o in output_blob:
66
  class_id_list = o.argsort()[::-1][:self.top_k]
67
  batched_class_id_list.append(class_id_list)
68
+ if len(self._labels) > 0 and self.load_label:
69
  batched_predicted_labels = []
70
  for class_id_list in batched_class_id_list:
71
  predicted_labels = []
models/image_classification_ppresnet/demo.py CHANGED
@@ -37,13 +37,16 @@ parser.add_argument('--backend_target', '-bt', type=int, default=0,
37
  {:d}: TIM-VX + NPU,
38
  {:d}: CANN + NPU
39
  '''.format(*[x for x in range(len(backend_target_pairs))]))
 
 
40
  args = parser.parse_args()
41
 
42
  if __name__ == '__main__':
43
  backend_id = backend_target_pairs[args.backend_target][0]
44
  target_id = backend_target_pairs[args.backend_target][1]
 
45
  # Instantiate ResNet
46
- model = PPResNet(modelPath=args.model, backendId=backend_id, targetId=target_id)
47
 
48
  # Read image and get a 224x224 crop from a 256x256 resized
49
  image = cv.imread(args.input)
 
37
  {:d}: TIM-VX + NPU,
38
  {:d}: CANN + NPU
39
  '''.format(*[x for x in range(len(backend_target_pairs))]))
40
+ parser.add_argument('--top_k', type=int, default=1,
41
+ help='Usage: Get top k predictions.')
42
  args = parser.parse_args()
43
 
44
  if __name__ == '__main__':
45
  backend_id = backend_target_pairs[args.backend_target][0]
46
  target_id = backend_target_pairs[args.backend_target][1]
47
+ top_k = args.top_k
48
  # Instantiate ResNet
49
+ model = PPResNet(modelPath=args.model, topK=top_k, backendId=backend_id, targetId=target_id)
50
 
51
  # Read image and get a 224x224 crop from a 256x256 resized
52
  image = cv.imread(args.input)
models/image_classification_ppresnet/ppresnet.py CHANGED
@@ -9,10 +9,11 @@ import numpy as np
9
  import cv2 as cv
10
 
11
  class PPResNet:
12
- def __init__(self, modelPath, topK=1, backendId=0, targetId=0):
13
  self._modelPath = modelPath
14
  assert topK >= 1
15
  self._topK = topK
 
16
  self._backendId = backendId
17
  self._targetId = targetId
18
 
@@ -69,7 +70,7 @@ class PPResNet:
69
  for ob in outputBlob:
70
  class_id_list = ob.argsort()[::-1][:self._topK]
71
  batched_class_id_list.append(class_id_list)
72
- if len(self._labels) > 0:
73
  batched_predicted_labels = []
74
  for class_id_list in batched_class_id_list:
75
  predicted_labels = []
 
9
  import cv2 as cv
10
 
11
  class PPResNet:
12
+ def __init__(self, modelPath, topK=1, loadLabel=True, backendId=0, targetId=0):
13
  self._modelPath = modelPath
14
  assert topK >= 1
15
  self._topK = topK
16
+ self._load_label = loadLabel
17
  self._backendId = backendId
18
  self._targetId = targetId
19
 
 
70
  for ob in outputBlob:
71
  class_id_list = ob.argsort()[::-1][:self._topK]
72
  batched_class_id_list.append(class_id_list)
73
+ if len(self._labels) > 0 and self._load_label:
74
  batched_predicted_labels = []
75
  for class_id_list in batched_class_id_list:
76
  predicted_labels = []
tools/eval/eval.py CHANGED
@@ -25,32 +25,38 @@ models = dict(
25
  name="MobileNet",
26
  topic="image_classification",
27
  modelPath=os.path.join(root_dir, "models/image_classification_mobilenet/image_classification_mobilenetv1_2022apr.onnx"),
28
- topK=5),
 
29
  mobilenetv1_q=dict(
30
  name="MobileNet",
31
  topic="image_classification",
32
  modelPath=os.path.join(root_dir, "models/image_classification_mobilenet/image_classification_mobilenetv1_2022apr_int8.onnx"),
33
- topK=5),
 
34
  mobilenetv2=dict(
35
  name="MobileNet",
36
  topic="image_classification",
37
  modelPath=os.path.join(root_dir, "models/image_classification_mobilenet/image_classification_mobilenetv2_2022apr.onnx"),
38
- topK=5),
 
39
  mobilenetv2_q=dict(
40
  name="MobileNet",
41
  topic="image_classification",
42
  modelPath=os.path.join(root_dir, "models/image_classification_mobilenet/image_classification_mobilenetv2_2022apr_int8.onnx"),
43
- topK=5),
 
44
  ppresnet=dict(
45
  name="PPResNet",
46
  topic="image_classification",
47
  modelPath=os.path.join(root_dir, "models/image_classification_ppresnet/image_classification_ppresnet50_2022jan.onnx"),
48
- topK=5),
 
49
  ppresnet_q=dict(
50
  name="PPResNet",
51
  topic="image_classification",
52
  modelPath=os.path.join(root_dir, "models/image_classification_ppresnet/image_classification_ppresnet50_2022jan_int8.onnx"),
53
- topK=5),
 
54
  yunet=dict(
55
  name="YuNet",
56
  topic="face_detection",
 
25
  name="MobileNet",
26
  topic="image_classification",
27
  modelPath=os.path.join(root_dir, "models/image_classification_mobilenet/image_classification_mobilenetv1_2022apr.onnx"),
28
+ topK=5,
29
+ loadLabel=False),
30
  mobilenetv1_q=dict(
31
  name="MobileNet",
32
  topic="image_classification",
33
  modelPath=os.path.join(root_dir, "models/image_classification_mobilenet/image_classification_mobilenetv1_2022apr_int8.onnx"),
34
+ topK=5,
35
+ loadLabel=False),
36
  mobilenetv2=dict(
37
  name="MobileNet",
38
  topic="image_classification",
39
  modelPath=os.path.join(root_dir, "models/image_classification_mobilenet/image_classification_mobilenetv2_2022apr.onnx"),
40
+ topK=5,
41
+ loadLabel=False),
42
  mobilenetv2_q=dict(
43
  name="MobileNet",
44
  topic="image_classification",
45
  modelPath=os.path.join(root_dir, "models/image_classification_mobilenet/image_classification_mobilenetv2_2022apr_int8.onnx"),
46
+ topK=5,
47
+ loadLabel=False),
48
  ppresnet=dict(
49
  name="PPResNet",
50
  topic="image_classification",
51
  modelPath=os.path.join(root_dir, "models/image_classification_ppresnet/image_classification_ppresnet50_2022jan.onnx"),
52
+ topK=5,
53
+ loadLabel=False),
54
  ppresnet_q=dict(
55
  name="PPResNet",
56
  topic="image_classification",
57
  modelPath=os.path.join(root_dir, "models/image_classification_ppresnet/image_classification_ppresnet50_2022jan_int8.onnx"),
58
+ topK=5,
59
+ loadLabel=False),
60
  yunet=dict(
61
  name="YuNet",
62
  topic="face_detection",