ONNX
Wanli commited on
Commit
89a8880
·
1 Parent(s): bd42ba7

Add 'load_label' parameter for image classification models (#185)

Browse files

* add 'load_label' parameter for image classification models

* move load_label flag to initializer

Files changed (2) hide show
  1. demo.py +4 -1
  2. mobilenet.py +3 -2
demo.py CHANGED
@@ -31,13 +31,16 @@ parser.add_argument('--backend_target', '-bt', type=int, default=0,
31
  {:d}: TIM-VX + NPU,
32
  {:d}: CANN + NPU
33
  '''.format(*[x for x in range(len(backend_target_pairs))]))
 
 
34
  args = parser.parse_args()
35
 
36
  if __name__ == '__main__':
37
  backend_id = backend_target_pairs[args.backend_target][0]
38
  target_id = backend_target_pairs[args.backend_target][1]
 
39
  # Instantiate MobileNet
40
- model = MobileNet(modelPath=args.model, backendId=backend_id, targetId=target_id)
41
 
42
  # Read image and get a 224x224 crop from a 256x256 resized
43
  image = cv.imread(args.input)
 
31
  {:d}: TIM-VX + NPU,
32
  {:d}: CANN + NPU
33
  '''.format(*[x for x in range(len(backend_target_pairs))]))
34
+ parser.add_argument('--top_k', type=int, default=1,
35
+ help='Usage: Get top k predictions.')
36
  args = parser.parse_args()
37
 
38
  if __name__ == '__main__':
39
  backend_id = backend_target_pairs[args.backend_target][0]
40
  target_id = backend_target_pairs[args.backend_target][1]
41
+ top_k = args.top_k
42
  # Instantiate MobileNet
43
+ model = MobileNet(modelPath=args.model, topK=top_k, backendId=backend_id, targetId=target_id)
44
 
45
  # Read image and get a 224x224 crop from a 256x256 resized
46
  image = cv.imread(args.input)
mobilenet.py CHANGED
@@ -6,10 +6,11 @@ class MobileNet:
6
  Works with MobileNet V1 & V2.
7
  '''
8
 
9
- def __init__(self, modelPath, topK=1, backendId=0, targetId=0):
10
  self.model_path = modelPath
11
  assert topK >= 1
12
  self.top_k = topK
 
13
  self.backend_id = backendId
14
  self.target_id = targetId
15
 
@@ -64,7 +65,7 @@ class MobileNet:
64
  for o in output_blob:
65
  class_id_list = o.argsort()[::-1][:self.top_k]
66
  batched_class_id_list.append(class_id_list)
67
- if len(self._labels) > 0:
68
  batched_predicted_labels = []
69
  for class_id_list in batched_class_id_list:
70
  predicted_labels = []
 
6
  Works with MobileNet V1 & V2.
7
  '''
8
 
9
+ def __init__(self, modelPath, topK=1, loadLabel=True, backendId=0, targetId=0):
10
  self.model_path = modelPath
11
  assert topK >= 1
12
  self.top_k = topK
13
+ self.load_label = loadLabel
14
  self.backend_id = backendId
15
  self.target_id = targetId
16
 
 
65
  for o in output_blob:
66
  class_id_list = o.argsort()[::-1][:self.top_k]
67
  batched_class_id_list.append(class_id_list)
68
+ if len(self._labels) > 0 and self.load_label:
69
  batched_predicted_labels = []
70
  for class_id_list in batched_class_id_list:
71
  predicted_labels = []