File size: 1,649 Bytes
cd17133
 
a49ba8d
d4b4b25
a49ba8d
e1b26df
 
a49ba8d
e1b26df
 
 
 
 
 
d4b4b25
 
e1b26df
 
 
 
d4b4b25
 
e1b26df
 
d4b4b25
e1b26df
 
448af73
e1b26df
 
 
cd17133
e1b26df
 
 
cd17133
448af73
e1b26df
cd17133
e1b26df
 
 
 
 
 
448af73
e1b26df
 
cd17133
e1b26df
 
cd17133
e1b26df
 
331746a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import numpy as np
import cv2
import gradio as gr
import tensorflow as tf

# Heading displayed at the top of the page.
title = "Welcome on your first sketch recognition app!"

# HTML shown under the title: a centred sample image followed by a usage hint.
head = "".join(
    (
        "<center>",
        "<img src='./mnist-classes.png' width=400>",
        "The robot was trained to classify numbers (from 0 to 9). To test it, write your number in the space provided.",
        "</center>",
    )
)

# Markdown footer linking to the GitHub repository with the full code.
ref = (
    "Find the whole code "
    "[here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
)

# Side length in pixels of the square image passed to the model (28x28).
img_size = 28

# Class names, zero through nine; predict() pairs them positionally with the
# model's outputs.
labels = [
    "zero",
    "one",
    "two",
    "three",
    "four",
    "five",
    "six",
    "seven",
    "eight",
    "nine",
]

# Restore the saved Keras model (trained on the MNIST dataset) from the .h5
# file next to this script; this runs once, at import time.
model = tf.keras.models.load_model(
    "./sketch_recognition_numbers_model.h5"
)

# prediction function for sketch recognition
def predict(img):
    """Run the digit model on a sketch and return one value per class name.

    Args:
        img: The drawing handed over by the Gradio input. It is converted with
            ``np.array``; a 3-D result is treated as RGB and reduced to
            grayscale, anything else is passed on unchanged.

    Returns:
        A dict whose keys are the entries of ``labels`` and whose values are
        the corresponding model outputs as plain Python floats.
    """
    pixels = np.array(img)

    # Three dimensions means a channel axis is present: collapse it so a
    # single grayscale plane remains.
    if pixels.ndim == 3:
        pixels = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)

    # Scale to img_size x img_size, then add the batch and channel axes to get
    # the (1, img_size, img_size, 1) layout that is fed to the model.
    side = img_size
    batch = cv2.resize(pixels, (side, side)).reshape(1, side, side, 1)

    # The batch holds exactly one image, so keep only the first output row.
    scores = model.predict(batch)[0]

    # Pair each class name with its output, converted to a built-in float.
    return dict(zip(labels, map(float, scores)))

# Output widget limited to the three highest-scoring classes.
label = gr.Label(num_top_classes=3)

# Build the Gradio app (sketchpad in, label out, texts defined above) and
# start serving it immediately.
interface = gr.Interface(
    fn=predict,
    inputs="sketchpad",
    outputs=label,
    title=title,
    description=head,
    article=ref,
)
interface.launch()