ONNX
Wanli commited on
Commit
2d6cdf7
·
1 Parent(s): 13ef3c1

Add 'load_label' parameter for image classification models (#185)

Browse files

* add 'load_label' parameter for image classification models

* move load_label flag to initializer

Files changed (2) hide show
  1. demo.py +4 -1
  2. ppresnet.py +3 -2
demo.py CHANGED
@@ -37,13 +37,16 @@ parser.add_argument('--backend_target', '-bt', type=int, default=0,
37
  {:d}: TIM-VX + NPU,
38
  {:d}: CANN + NPU
39
  '''.format(*[x for x in range(len(backend_target_pairs))]))
 
 
40
  args = parser.parse_args()
41
 
42
  if __name__ == '__main__':
43
  backend_id = backend_target_pairs[args.backend_target][0]
44
  target_id = backend_target_pairs[args.backend_target][1]
 
45
  # Instantiate ResNet
46
- model = PPResNet(modelPath=args.model, backendId=backend_id, targetId=target_id)
47
 
48
  # Read image and get a 224x224 crop from a 256x256 resized
49
  image = cv.imread(args.input)
 
37
  {:d}: TIM-VX + NPU,
38
  {:d}: CANN + NPU
39
  '''.format(*[x for x in range(len(backend_target_pairs))]))
40
+ parser.add_argument('--top_k', type=int, default=1,
41
+ help='Usage: Get top k predictions.')
42
  args = parser.parse_args()
43
 
44
  if __name__ == '__main__':
45
  backend_id = backend_target_pairs[args.backend_target][0]
46
  target_id = backend_target_pairs[args.backend_target][1]
47
+ top_k = args.top_k
48
  # Instantiate ResNet
49
+ model = PPResNet(modelPath=args.model, topK=top_k, backendId=backend_id, targetId=target_id)
50
 
51
  # Read image and get a 224x224 crop from a 256x256 resized
52
  image = cv.imread(args.input)
ppresnet.py CHANGED
@@ -9,10 +9,11 @@ import numpy as np
9
  import cv2 as cv
10
 
11
  class PPResNet:
12
- def __init__(self, modelPath, topK=1, backendId=0, targetId=0):
13
  self._modelPath = modelPath
14
  assert topK >= 1
15
  self._topK = topK
 
16
  self._backendId = backendId
17
  self._targetId = targetId
18
 
@@ -69,7 +70,7 @@ class PPResNet:
69
  for ob in outputBlob:
70
  class_id_list = ob.argsort()[::-1][:self._topK]
71
  batched_class_id_list.append(class_id_list)
72
- if len(self._labels) > 0:
73
  batched_predicted_labels = []
74
  for class_id_list in batched_class_id_list:
75
  predicted_labels = []
 
9
  import cv2 as cv
10
 
11
  class PPResNet:
12
+ def __init__(self, modelPath, topK=1, loadLabel=True, backendId=0, targetId=0):
13
  self._modelPath = modelPath
14
  assert topK >= 1
15
  self._topK = topK
16
+ self._load_label = loadLabel
17
  self._backendId = backendId
18
  self._targetId = targetId
19
 
 
70
  for ob in outputBlob:
71
  class_id_list = ob.argsort()[::-1][:self._topK]
72
  batched_class_id_list.append(class_id_list)
73
+ if len(self._labels) > 0 and self._load_label:
74
  batched_predicted_labels = []
75
  for class_id_list in batched_class_id_list:
76
  predicted_labels = []