File size: 1,693 Bytes
051f92c
a49ba8d
d4b4b25
a49ba8d
e6fdb4c
 
a49ba8d
e6fdb4c
e1b26df
e6fdb4c
 
 
 
 
d4b4b25
 
e1b26df
e6fdb4c
e1b26df
e6fdb4c
d4b4b25
 
e6fdb4c
e1b26df
d4b4b25
e6fdb4c
 
c40b85e
e6fdb4c
 
9cc90a0
 
 
 
 
 
 
 
 
 
 
 
 
 
051f92c
e6fdb4c
 
 
 
c40b85e
051f92c
9cc90a0
e6fdb4c
051f92c
 
c40b85e
 
331746a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import numpy as np
import gradio as gr
import tensorflow as tf

# Title shown by the Gradio interface
title = "Welcome to your first sketch recognition app!"

# HTML description shown by the interface: a reference image plus instructions
head = "".join(
    (
        "<center>",
        "<img src='./mnist-classes.png' width=400>",
        "<p>The model is trained to classify numbers (from 0 to 9). ",
        "To test it, draw your number in the space provided.</p>",
        "</center>",
    )
)

# Markdown link to the GitHub repository, passed to the interface as its article
ref = "Find the complete code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."

# Side length, in pixels, of the square image the input is reshaped to (28x28)
img_size = 28

# Human-readable class names, indexed by digit (0 to 9)
labels = "zero one two three four five six seven eight nine".split()

# Load the saved Keras model (trained on the MNIST dataset) once at import time
model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")

# Prediction function for sketch recognition
def predict(data):
    """Classify a sketched digit and return its three most likely classes.

    Args:
        data: The sketch to classify. Either an array-like holding
            ``img_size * img_size`` pixel values, or the dict that Gradio's
            ImageEditor-based Sketchpad passes to the callback, in which case
            the flattened ``"composite"`` image is used.

    Returns:
        A dict mapping the three highest-scoring class names (taken from
        ``labels``) to their scores as plain Python floats, ordered from most
        to least likely.

    Raises:
        ValueError: Raised by ``np.reshape`` when the input does not contain
            exactly ``img_size * img_size`` values.
    """
    # Newer Sketchpad components hand over a dict of images
    # (background / layers / composite) rather than a bare array.
    if isinstance(data, dict):
        data = data["composite"]
    # Reshape to a single-image batch with one channel: (1, 28, 28, 1)
    img = np.reshape(data, (1, img_size, img_size, 1))
    # Scores for the only image in the batch; assumed to hold one entry
    # per name in `labels` -- TODO confirm against the saved model.
    probs = model.predict(img)[0]
    # Indices of the three largest scores, highest first
    top_3_classes = np.argsort(probs)[-3:][::-1]
    # Cast NumPy scalars to built-in floats, the value type gr.Label documents
    return {labels[i]: float(probs[i]) for i in top_3_classes}

# Output component: a label widget limited to the top 3 classes
label = gr.Label(num_top_classes=3)

# Wire the sketchpad input, the prediction function and the label output
# into a Gradio interface, then start it
interface = gr.Interface(
    predict,
    gr.Sketchpad(
        crop_size=(28, 28),
        type="numpy",
        image_mode="L",
        brush=gr.Brush(),
    ),
    label,
    title=title,
    description=head,
    article=ref,
)
interface.launch()