# mnist / app.py
# Last commit by alibayram (d9d330c): Enhance sketch recognition app: add image to app description and update sketchpad input parameters
# Size: 1.88 kB
# import dependencies
import gradio as gr
import tensorflow as tf
import cv2
import numpy as np
# Title shown at the top of the app.
title = "Welcome to your first sketch recognition app!"

# App description (HTML): the image mnist-classes.png followed by usage instructions.
# NOTE(review): Gradio may not serve a bare relative image path such as
# 'mnist-classes.png' unless the file is explicitly exposed (for example via
# launch(allowed_paths=...)) — confirm the image actually renders.
head = "".join(
    (
        "<center>",
        "<img src='mnist-classes.png' width=400><br>",
        "The robot was trained to classify numbers (from 0 to 9). To test it, write your number in the space provided.",
        "</center>",
    )
)

# Markdown link to the full example code, passed to the interface as its article.
ref = "Find the whole code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."

# Side length, in pixels, of the square image fed to the model (28x28).
img_size = 28

# Class names, indexed by digit (0 through 9).
labels = "zero one two three four five six seven eight nine".split()
# Load the Keras classifier (trained on the MNIST dataset). The "./" path is
# resolved against the current working directory, not this file's location.
_model_path = "./sketch_recognition_numbers_model.h5"
model = tf.keras.models.load_model(_model_path)
# prediction function for sketch recognition
def predict(img):
    """Classify a hand-drawn digit and return a probability for each class.

    Args:
        img: The Sketchpad value. It is either an image array shaped
            (H, W) or (H, W, C), or a dict from newer Gradio editor
            components whose ``"composite"`` entry holds the flattened
            drawing. ``None`` means there is nothing to classify.

    Returns:
        A ``{label: probability}`` dict with one entry per name in
        ``labels``, or ``None`` when there is no image to classify.
    """
    # Newer Gradio Sketchpad/ImageEditor components return a dict of layers;
    # classify the flattened canvas.
    if isinstance(img, dict):
        img = img.get("composite")
    if img is None:
        return None

    img = np.asarray(img)
    if img.size == 0:
        # Nothing to resize or classify; cv2.resize would raise on this.
        return None

    # Reduce the image to a single grayscale channel.
    if img.ndim == 3:
        if img.shape[2] == 1:
            # Already single-channel; cv2.cvtColor(RGB2GRAY) rejects 1-channel input.
            img = img[:, :, 0]
        else:
            # COLOR_RGB2GRAY accepts 3- or 4-channel input; any alpha channel is dropped.
            # NOTE(review): if the canvas background is transparent, unpainted
            # pixels may share the strokes' RGB values, so the drawing would
            # live only in the alpha channel. Confirm this against the Gradio
            # version in use and the pixel polarity the model was trained on.
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    # When shrinking a large canvas, bilinear interpolation samples only a few
    # source pixels per output pixel and can drop thin strokes. INTER_AREA
    # averages over the whole source region instead.
    height, width = img.shape[:2]
    shrinking = height > img_size or width > img_size
    interpolation = cv2.INTER_AREA if shrinking else cv2.INTER_LINEAR
    img = cv2.resize(img, (img_size, img_size), interpolation=interpolation)

    # Divide by 255 (assumes 8-bit pixel values), then add the batch and
    # channel axes: (1, img_size, img_size, 1).
    img = img.astype("float32") / 255.0
    img = img.reshape(1, img_size, img_size, 1)

    # verbose=0 stops Keras from printing a progress bar on every request.
    preds = model.predict(img, verbose=0)[0]
    # Pair each class name with its probability.
    return {label: float(pred) for label, pred in zip(labels, preds)}
# Output component: shows the three highest-probability classes.
label = gr.Label(num_top_classes=3)

# Input component: a 280x280 drawing canvas.
canvas = gr.Sketchpad(height=280, width=280)

# Wire the canvas, the prediction function and the output label together,
# then start the web app.
interface = gr.Interface(
    fn=predict,
    inputs=canvas,
    outputs=label,
    title=title,
    description=head,
    article=ref,
)
interface.launch()