mnist / app.py
alibayram's picture
Refactor sketch recognition app: update app title and description, streamline image processing, and enhance prediction function
e6fdb4c
raw
history blame
1.73 kB
import numpy as np
import gradio as gr
import tensorflow as tf
import cv2
# Heading displayed at the top of the Gradio page
title = "Welcome to your first sketch recognition app!"

# HTML description rendered under the title: an illustration of the
# classes followed by short usage instructions.
head = "".join(
    [
        "<center>",
        "<img src='./mnist-classes.png' width=400>",
        "<p>The model is trained to classify numbers (from 0 to 9). ",
        "To test it, draw your number in the space provided.</p>",
        "</center>",
    ]
)

# Markdown footer pointing at the upstream source code
ref = (
    "Find the complete code "
    "[here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
)

# Side length, in pixels, of the square image fed to the model
img_size = 28

# Human-readable class names; position i holds the name of digit i
# (assumed to match the model's output order).
labels = "zero one two three four five six seven eight nine".split()
# Restore the trained Keras digit classifier (trained on MNIST) from its
# HDF5 file. The path is resolved relative to the process's current working
# directory, not to this script's location.
model = tf.keras.models.load_model(filepath="./sketch_recognition_numbers_model.h5")
def _extract_image(data):
    """Return the drawn image held in a Sketchpad value, or None if absent."""
    if isinstance(data, dict):
        # Older code paths store the drawing under "image"; Gradio 4+ editor
        # values store the flattened drawing under "composite".
        for key in ("image", "composite"):
            if data.get(key) is not None:
                return data[key]
        return None
    return data


def _to_grayscale(img):
    """Collapse an RGB/RGBA image array to one channel; pass grayscale through."""
    if img.ndim == 2:
        return img
    if img.ndim == 3 and img.shape[2] == 1:
        return img[:, :, 0]
    if img.ndim == 3 and img.shape[2] == 4:
        # NOTE(review): RGBA2GRAY ignores alpha. If strokes are drawn on a
        # transparent background, the drawing may need to be recovered from
        # the alpha channel instead -- confirm for the Gradio version in use.
        return cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    # Gradio builds numpy images via PIL, so channels are RGB (not OpenCV's BGR).
    return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)


# Prediction function for sketch recognition
def predict(data):
    """Classify a hand-drawn digit and return per-class probabilities.

    Args:
        data: The value Gradio passes from the ``gr.Sketchpad`` input. This is
            either an image array or a dict holding the drawing under
            ``"image"`` or (Gradio 4+) ``"composite"``. ``None``, or a dict
            with neither key set, means nothing was drawn.

    Returns:
        A dict mapping each name in ``labels`` to the model's score for it (as
        a Python ``float``), suitable for ``gr.Label``. Returns ``None`` when
        no image was provided, which leaves the label output empty.
    """
    img = _extract_image(data)
    if img is None:
        return None

    img = _to_grayscale(np.asarray(img))
    # NOTE(review): MNIST digits are light strokes on a dark background;
    # confirm whether the canvas output needs inverting for this model.
    # Resize image to the model's square input resolution
    img = cv2.resize(img, (img_size, img_size))
    # Normalize pixel values to [0, 1]
    img = img.astype("float32") / 255.0
    # Add batch and channel axes to match the model input: (1, 28, 28, 1)
    img = img.reshape(1, img_size, img_size, 1)
    # verbose=0 stops Keras printing a progress bar on every request
    preds = model.predict(img, verbose=0)[0]
    # Return the probability for each class
    return {label: float(pred) for label, pred in zip(labels, preds)}
# Output component: show the three highest-scoring classes
label = gr.Label(num_top_classes=3)

# Gradio interface wiring the sketchpad input through `predict` to the label
interface = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(),
    outputs=label,
    title=title,
    description=head,
    article=ref,
)

# Start the web server only when executed as a script (e.g. `python app.py`),
# so importing this module for tests or tooling does not block in launch().
if __name__ == "__main__":
    interface.launch()