DavidD003 commited on
Commit
33343db
·
1 Parent(s): 05582d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -4
app.py CHANGED
@@ -4,12 +4,33 @@ from PIL import Image
4
  #
5
  #learn = load_learner('export.pkl')
6
  learn = torch.load('digit_classifier.pth')
 
7
  labels = [str(x) for x in range(10)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  def predict(img):
9
  #First take input and reduce it to 8x8 px as the dataset was
 
 
 
 
 
 
 
10
 
11
- img = PILImage.create(img)
12
- pred,pred_idx,probs = learn.predict(img)
13
- return {labels[i]: float(probs[i]) for i in range(len(labels))}
14
 
15
- gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(512, 512)), outputs=gr.outputs.Label(num_top_classes=3)).launch(share=True)
 
4
  #
5
  #learn = load_learner('export.pkl')
6
  learn = torch.load('digit_classifier.pth')
7
+ learn.eval() #switch to eval mode
8
  labels = [str(x) for x in range(10)]
9
+
10
+ #Define function to reduce image of arbitrary size to 8x8 per model requirements.
11
+ def reduce_image_count(image):
12
+ output_size = (8, 8)
13
+ block_size = (image.shape[0] // output_size[0], image.shape[1] // output_size[1])
14
+ output = np.zeros(output_size)
15
+
16
+ for i in range(output_size[0]):
17
+ for j in range(output_size[1]):
18
+ block = image[i*block_size[0]:(i+1)*block_size[0], j*block_size[1]:(j+1)*block_size[1]]
19
+ count = np.count_nonzero(block)
20
+ output[i, j] = 16 - ((count / (block_size[0] * block_size[1])) * 16)
21
+
22
+ return output
23
+
24
  def predict(img):
25
  #First take input and reduce it to 8x8 px as the dataset was
26
+ pil_image = Image.open(img) #get image
27
+ gray_img = pil_image.convert('L')#grayscale
28
+ pic = np.array(gray_img) #convert to array
29
+ inp_img=reduce_image_count(pic)#Reduce image to required input size
30
+
31
+ otpt=F.softmax(learn.forward(inp_img.view(-1,64)))
32
+ #pred,pred_idx,probs = learn.predict(img)
33
 
34
+ return {labels[i]: float(otpt[0].data[i]) for i in range(len(labels)),'image': inp_img}
 
 
35
 
36
+ gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(512, 512)), outputs=[gr.outputs.Label(num_top_classes=3), gr.outputs.Image()]).launch(share=True)