# mnist / app.py
# alibayram's picture
# Refactor predict function: streamline image extraction and add debug prints for image shape and content
# 676005c
# raw
# history blame
# 2.14 kB
import numpy as np
import gradio as gr
import tensorflow as tf
import cv2
# Heading displayed at the top of the app
title = "Welcome to your first sketch recognition app!"

# Introductory HTML rendered below the heading
head = "".join([
    "<center>",
    "<img src='./mnist-classes.png' width=400>",
    "<p>The model is trained to classify numbers (from 0 to 9). ",
    "To test it, draw your number in the space provided.</p>",
    "</center>",
])

# Markdown pointer to the source repository on GitHub
ref = (
    "Find the complete code "
    "[here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
)
# Human-readable class names, keyed by the digit they represent (0 through 9)
labels = dict(enumerate((
    "zero",
    "one",
    "two",
    "three",
    "four",
    "five",
    "six",
    "seven",
    "eight",
    "nine",
)))
# Load the pre-trained Keras model (per the original note, trained on the MNIST dataset).
# This is a module-level statement, so it runs once when the module is imported/executed;
# the path is relative, so it is resolved against the process working directory.
model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")
def predict(data):
    """Classify a hand-drawn digit and return its most likely classes.

    Args:
        data: Sketchpad payload, a mapping whose ``'composite'`` entry holds
            the drawing as an array-like image (grayscale, RGB or RGBA).
            ``None``, or a ``None`` composite, is treated as an empty canvas.

    Returns:
        A dict mapping class names (values of the module-level ``labels``)
        to the model's score as a plain ``float``, for at most the three
        highest-scoring classes, best first. An empty dict is returned when
        there is no drawing to classify.
    """
    # Nothing to classify yet (e.g. the canvas was cleared or never drawn on).
    if data is None or data['composite'] is None:
        return {}

    # Convert to NumPy array
    img = np.array(data['composite'])
    print("img.shape", img.shape)

    # Collapse colour channels to a single grayscale plane. The rank check
    # keeps a 2-D grayscale image that happens to be 3 or 4 pixels wide from
    # being mistaken for an RGB/RGBA image.
    # NOTE(review): RGBA2RGB discards the alpha channel; if the canvas
    # background is transparent the strokes may be lost here - confirm.
    # NOTE(review): presumably the model expects light strokes on a dark
    # background (MNIST convention); no inversion is applied - verify.
    if img.ndim == 3 and img.shape[-1] == 4:  # RGBA
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
    if img.ndim == 3 and img.shape[-1] == 3:  # RGB
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    # Resize image to 28x28
    img = cv2.resize(img, (28, 28))
    # Normalize pixel values to [0, 1]
    img = img / 255.0
    # Reshape to match model input: (batch, height, width, channels)
    img = img.reshape(1, 28, 28, 1)
    print("img", img)

    # Model predictions for the single image in the batch
    preds = np.asarray(model.predict(img)[0])
    print("preds", preds)

    # Rank classes by index rather than by score value: using scores as
    # dictionary keys collapses classes that share the same score and then
    # reports the wrong label. The stable sort keeps the lower class index
    # first on ties, and the slice copes with fewer than three outputs.
    top_indices = np.argsort(-preds, kind="stable")[:3]

    labels_map = {}
    for idx in top_indices:
        idx = int(idx)
        # Plain float instead of a NumPy scalar so the result serializes cleanly.
        score = float(preds[idx])
        print("sorted_values[i]", score, idx)
        labels_map[labels[idx]] = score
    print("labels_map", labels_map)
    return labels_map
# Output widget restricted to the three highest-ranked classes
label = gr.Label(num_top_classes=3)

# Wire the sketchpad input to the classifier and describe the page
sketchpad = gr.Sketchpad(type='numpy')
interface = gr.Interface(
    predict,
    sketchpad,
    label,
    title=title,
    description=head,
    article=ref,
)

# Start serving the app, requesting a public share link
interface.launch(share=True)