File size: 1,693 Bytes
051f92c a49ba8d d4b4b25 a49ba8d e6fdb4c a49ba8d e6fdb4c e1b26df e6fdb4c d4b4b25 e1b26df e6fdb4c e1b26df e6fdb4c d4b4b25 e6fdb4c e1b26df d4b4b25 e6fdb4c c40b85e e6fdb4c 9cc90a0 051f92c e6fdb4c c40b85e 051f92c 9cc90a0 e6fdb4c 051f92c c40b85e 331746a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import numpy as np
import gradio as gr
import tensorflow as tf
# Page title shown at the top of the app.
title = "Welcome to your first sketch recognition app!"

# HTML description rendered under the title: an illustration of the digit
# classes followed by a short usage hint.
# NOTE(review): the relative <img> path may not be served by Gradio as-is;
# confirm the image actually displays in the deployed app.
head = "".join(
    (
        "<center>",
        "<img src='./mnist-classes.png' width=400>",
        "<p>The model is trained to classify numbers (from 0 to 9). ",
        "To test it, draw your number in the space provided.</p>",
        "</center>",
    )
)

# Footer text (Markdown link) pointing to the source repository.
ref = (
    "Find the complete code "
    "[here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
)

# Side length, in pixels, of the square images the model consumes.
img_size = 28

# Human-readable class names; entry i is the name of digit i.
labels = "zero one two three four five six seven eight nine".split()

# Pre-trained Keras classifier (trained on MNIST). The "./" path is resolved
# against the current working directory, not this file's location.
model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
# Prediction function for sketch recognition
def predict(data):
    """Classify a hand-drawn digit and return the three most likely labels.

    Args:
        data: The sketch from the ``gr.Sketchpad`` input. Depending on the
            Gradio version this is either a NumPy array or a dict that holds
            the rendered image under ``"composite"`` (assumed from the Gradio
            4 ImageEditor value format -- TODO confirm against the pinned
            Gradio version). It must contain exactly ``img_size * img_size``
            pixels so it can be reshaped for the model.

    Returns:
        A ``{label: probability}`` mapping with the three highest-scoring
        classes, highest first, as consumed by ``gr.Label``; ``None`` when no
        sketch was submitted (leaves the Label output empty).
    """
    # Newer Gradio sketchpads may wrap the image in a dict; unwrap it so both
    # value formats are accepted.
    if isinstance(data, dict):
        data = data.get("composite")
    if data is None:
        # Nothing drawn / submitted: show an empty result instead of crashing
        # inside np.reshape.
        return None

    # The model is fed a batch of one single-channel img_size x img_size image.
    # NOTE(review): no pixel scaling is applied here; confirm this matches the
    # preprocessing used when the model was trained.
    img = np.reshape(data, (1, img_size, img_size, 1))
    probs = model.predict(img)[0]

    # Indices of the three largest scores, ordered highest first.
    top_3_classes = np.argsort(probs)[-3:][::-1]

    # float() converts the NumPy scalars to plain Python floats so the mapping
    # serializes cleanly for the Label component.
    return {labels[i]: float(probs[i]) for i in top_3_classes}
# Output widget: displays the three most likely classes with their scores.
label = gr.Label(num_top_classes=3)

# Gradio interface: a grayscale sketchpad cropped to the model's input size
# feeds predict(), whose label -> probability mapping is shown by `label`.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(
        crop_size=(img_size, img_size),
        type="numpy",
        image_mode="L",
        brush=gr.Brush(),
    ),
    outputs=label,
    title=title,
    description=head,
    article=ref,
)

# Start the web server only when executed as a script, so importing this
# module (e.g. from tests or another launcher) has no server side effect.
if __name__ == "__main__":
    interface.launch()