File size: 1,878 Bytes
c40b85e
a49ba8d
d4b4b25
c40b85e
 
a49ba8d
e1b26df
c40b85e
a49ba8d
e1b26df
 
c40b85e
d9d330c
c40b85e
 
d4b4b25
 
e1b26df
 
 
 
d4b4b25
 
e1b26df
 
d4b4b25
e1b26df
 
448af73
e1b26df
 
c40b85e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd17133
e1b26df
 
cd17133
e1b26df
c40b85e
 
d9d330c
c40b85e
 
 
 
 
331746a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# import dependencies
import gradio as gr
import tensorflow as tf
import cv2
import numpy as np

# heading handed to the interface as its title
title = "Welcome to your first sketch recognition app!"

# HTML description: a centered image followed by usage instructions
head = "".join(
    [
        "<center>",
        "<img src='mnist-classes.png' width=400><br>",
        "The robot was trained to classify numbers (from 0 to 9). To test it, write your number in the space provided.",
        "</center>",
    ]
)

# markdown link to the GitHub repository holding the full example
ref = "Find the whole code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."

# side length, in pixels, of the square image fed to the model (28x28)
img_size = 28

# class names, ordered so that index i is the name of digit i (0 to 9)
labels = "zero one two three four five six seven eight nine".split()

# restore the Keras model (trained on the MNIST dataset) from the .h5 file
# that sits next to this script
model = tf.keras.models.load_model(
    filepath="./sketch_recognition_numbers_model.h5",
)

# prediction function for sketch recognition
def predict(img):
    """Classify a hand-drawn digit and return one score per class.

    Args:
        img: The sketch delivered by the Gradio input component. Accepted
            forms are an array-like image (2-D grayscale, or 3-D with a
            trailing channel axis), a dict carrying the flattened drawing
            under the ``"composite"`` key (the payload produced by
            editor-style sketchpads), or ``None`` when nothing was drawn.

    Returns:
        A dict mapping each entry of ``labels`` to the model output for that
        class (as a Python float), or ``None`` when there is no image to
        classify.
    """
    # Editor-style sketchpads pass a dict instead of a bare image; feeding
    # that dict to numpy would yield a 0-d object array and break the resize.
    if isinstance(img, dict):
        img = img.get("composite")

    # Nothing drawn (or an empty editor payload): nothing to predict.
    if img is None:
        return None

    img = np.asarray(img)

    # Ensure a single-channel 2-D image.
    if img.ndim == 3:
        if img.shape[2] == 1:
            # Already grayscale, just drop the redundant channel axis.
            img = img.reshape(img.shape[:2])
        else:
            # NOTE(review): the conversion is presumed to use only the colour
            # channels; if the canvas has a transparent background the strokes
            # may be visible only in the alpha channel -- confirm against the
            # Gradio version in use.
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    # Resize to the model's input resolution.
    img = cv2.resize(img, (img_size, img_size))

    # Scale by 1/255 as float32, then add the batch and channel axes:
    # (1, img_size, img_size, 1).
    img = img.astype("float32") / 255.0
    img = img.reshape(1, img_size, img_size, 1)

    # Single-image batch, so keep only the first row of predictions.
    preds = model.predict(img)[0]

    # Pair every class name with its score, cast to a plain float.
    return {name: float(score) for name, score in zip(labels, preds)}

# output widget: display only the three highest-scoring classes
label = gr.Label(num_top_classes=3)

# assemble the Gradio app: a 280x280 sketchpad wired to `predict`, with the
# title, HTML description and repository link defined earlier in this file
interface = gr.Interface(
    predict,
    gr.Sketchpad(width=280, height=280),
    label,
    title=title,
    description=head,
    article=ref,
)

# start serving the app
interface.launch()