akhaliq HF Staff commited on
Commit
8e15e86
·
1 Parent(s): 8014f33

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -47
app.py CHANGED
@@ -3,69 +3,64 @@ import matplotlib.pyplot as plt
3
  import numpy as np
4
  from collections import namedtuple
5
  from mxnet.gluon.data.vision import transforms
6
- from mxnet.contrib.onnx.onnx2mx.import_model import import_model
7
  import os
8
  import gradio as gr
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  mx.test_utils.download('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
11
 
12
  mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')
13
  with open('synset.txt', 'r') as f:
14
  labels = [l.rstrip() for l in f]
15
 
16
- os.system("wget https://github.com/AK391/models/raw/main/vision/classification/shufflenet/model/shufflenet-v2-10.onnx")
17
 
18
- # Enter path to the ONNX model file
19
 
20
- sym, arg_params, aux_params = import_model('shufflenet-v2-10.onnx')
21
-
22
- Batch = namedtuple('Batch', ['data'])
23
- def get_image(path, show=False):
24
- img = mx.image.imread(path)
25
- if img is None:
26
- return None
27
- if show:
28
- plt.imshow(img.asnumpy())
29
- plt.axis('off')
30
- return img
31
-
32
- def preprocess(img):
33
- transform_fn = transforms.Compose([
34
- transforms.Resize(256),
35
- transforms.CenterCrop(224),
36
- transforms.ToTensor(),
37
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
38
- ])
39
- img = transform_fn(img)
40
- img = img.expand_dims(axis=0)
41
- return img
42
 
43
  def predict(path):
44
- img = get_image(path, show=True)
45
- img = preprocess(img)
46
- mod.forward(Batch([img]))
47
- # Take softmax to generate probabilities
48
- scores = mx.ndarray.softmax(mod.get_outputs()[0]).asnumpy()
49
- # print the top-5 inferences class
50
- scores = np.squeeze(scores)
51
- a = np.argsort(scores)[::-1]
52
  results = {}
53
  for i in a[0:5]:
54
- results[labels[i]] = float(scores[i])
55
  return results
56
-
57
- # Determine and set context
58
- if len(mx.test_utils.list_gpus())==0:
59
- ctx = mx.cpu()
60
- else:
61
- ctx = mx.gpu(0)
62
- # Load module
63
- mod = mx.mod.Module(symbol=sym, context=ctx, data_names=['data'], label_names=None)
64
- mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))],label_shapes=mod._label_shapes)
65
- mod.set_params(arg_params, aux_params, allow_missing=True, allow_extra=True)
66
 
67
- title="ResNet"
68
- description="ResNet models perform image classification - they take images as input and classify the major object in the image into a set of pre-defined classes. They are trained on ImageNet dataset which contains images from 1000 classes. ResNet models provide very high accuracies with affordable model sizes. They are ideal for cases when high accuracy of classification is required."
69
 
70
  examples=[['catonnx.jpg']]
71
- gr.Interface(predict,gr.inputs.Image(type='filepath'),"label",title=title,description=description,examples=examples).launch(enable_queue=True)
 
3
  import numpy as np
4
  from collections import namedtuple
5
  from mxnet.gluon.data.vision import transforms
 
6
  import os
7
  import gradio as gr
8
 
9
+ from PIL import Image
10
+ import imageio
11
+ import onnxruntime as ort
12
+
13
+ def get_image(path):
14
+ '''
15
+ Using path to image, return the RGB load image
16
+ '''
17
+ img = imageio.imread(path, pilmode='RGB')
18
+ return img
19
+
20
+ # Pre-processing function for ImageNet models using numpy
21
+ def preprocess(img):
22
+ '''
23
+ Preprocessing required on the images for inference with mxnet gluon
24
+ The function takes loaded image and returns processed tensor
25
+ '''
26
+ img = np.array(Image.fromarray(img).resize((224, 224))).astype(np.float32)
27
+ img[:, :, 0] -= 123.68
28
+ img[:, :, 1] -= 116.779
29
+ img[:, :, 2] -= 103.939
30
+ img[:,:,[0,1,2]] = img[:,:,[2,1,0]]
31
+ img = img.transpose((2, 0, 1))
32
+ img = np.expand_dims(img, axis=0)
33
+
34
+ return img
35
+
36
  mx.test_utils.download('https://s3.amazonaws.com/model-server/inputs/kitten.jpg')
37
 
38
  mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')
39
  with open('synset.txt', 'r') as f:
40
  labels = [l.rstrip() for l in f]
41
 
42
+ os.system("wget https://github.com/onnx/models/raw/main/vision/classification/inception_and_googlenet/googlenet/model/googlenet-9.onnx")
43
 
44
+ ort_session = ort.InferenceSession("googlenet-9.onnx")
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  def predict(path):
48
+ img_batch = preprocess(get_image(path))
49
+
50
+ outputs = ort_session.run(
51
+ None,
52
+ {"data_0": img_batch.astype(np.float32)},
53
+ )
54
+
55
+ a = np.argsort(-outputs[0].flatten())
56
  results = {}
57
  for i in a[0:5]:
58
+ results[labels[i]]=float(outputs[0][0][i])
59
  return results
60
+
 
 
 
 
 
 
 
 
 
61
 
62
+ title="GoogleNet"
63
+ description="GoogLeNet is the name of a convolutional neural network for classification, which competed in the ImageNet Large Scale Visual Recognition Challenge in 2014."
64
 
65
  examples=[['catonnx.jpg']]
66
+ gr.Interface(predict,gr.inputs.Image(type='filepath'),"label",title=title,description=description,examples=examples).launch(enable_queue=True,debug=True)