# Spaces: Sleeping
# (Page-status banner captured together with the source; not part of the program.)
import numpy as np | |
import gradio as gr | |
import tensorflow as tf | |
import cv2 | |
# Restore the trained MNIST classifier from its Keras archive
# (the path is resolved relative to the current working directory).
_MODEL_PATH = "./number_recognition_model_colab.keras"
model = tf.keras.models.load_model(_MODEL_PATH)

# Human-readable name for each class index, 0 through 9.
labels = "zero one two three four five six seven eight nine".split()
def predict(data):
    """Classify a hand-drawn digit and return its most likely classes.

    Args:
        data: Value emitted by the sketchpad component. It is read as a
            mapping whose ``"composite"`` entry holds the drawn image as
            something ``np.asarray`` can convert to a 2-D (grayscale),
            H x W x 3 (RGB) or H x W x 4 (RGBA) array. ``None``, or a
            ``None`` composite, is treated as "nothing drawn yet".

    Returns:
        A dict mapping class name to probability (as a Python ``float``)
        for the up-to-three highest-scoring classes, ordered from most to
        least likely, or ``None`` when there is no image to classify.

    Raises:
        KeyError: If ``data`` is a mapping without a ``"composite"`` key.
    """
    # An empty canvas may reach us as None or as a dict with no composite
    # image; bail out instead of crashing on the shape checks below.
    if data is None:
        return None
    composite = data["composite"]
    if composite is None:
        return None

    img = np.asarray(composite)

    # Only a 3-D array has a channel axis. Checking ndim first keeps a 2-D
    # grayscale image that happens to be 3 or 4 pixels wide from being
    # mistaken for RGB/RGBA.
    # NOTE(review): COLOR_RGBA2RGB discards the alpha channel. If the
    # sketchpad delivers strokes on a transparent background, confirm the
    # flattened image still has the stroke/background contrast the model
    # was trained on.
    if img.ndim == 3 and img.shape[-1] == 4:  # RGBA
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
    if img.ndim == 3 and img.shape[-1] == 3:  # RGB
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    # Resize to 28x28, scale pixel values to [0, 1], and add the batch and
    # channel axes: (1, 28, 28, 1).
    img = cv2.resize(img, (28, 28))
    img = img / 255.0
    img = img.reshape(1, 28, 28, 1)

    # Scores for the single image in the batch.
    preds = model.predict(img)[0]
    print(preds)

    # Indices of the three highest scores, best first. Slicing yields fewer
    # than three entries if the model has fewer outputs, so nothing below
    # assumes exactly three.
    top_3_classes = np.argsort(preds)[-3:][::-1]
    top_3_probs = preds[top_3_classes]
    class_names = [labels[i] for i in top_3_classes]
    print(class_names, top_3_probs, top_3_classes)

    # Plain floats keep the result serialisable for the output component.
    return {name: float(prob) for name, prob in zip(class_names, top_3_probs)}
# Title and description
title = "Welcome to your first sketch recognition app!"
head = (
    "<center>"
    "<img src='./mnist-classes.png' width=400>"
    "<p>The model is trained to classify numbers (from 0 to 9). "
    "To test it, draw your number in the space provided (use the editing tools in the image editor).</p>"
    "</center>"
)

with gr.Blocks(title=title) as demo:
    # Description shown at the top of the page. `title` is handed to
    # gr.Blocks above rather than rendered here.
    gr.Markdown(head)
    # A second gr.Markdown(ref) call used to follow, but `ref` is not defined
    # anywhere in this script, so building the UI raised NameError. It has
    # been removed; add the text back as a defined string if it is needed.

    with gr.Row():
        # Drawing canvas; type="numpy" makes the images arrive as arrays.
        im = gr.Sketchpad(type="numpy", label="Draw your digit here (use brush and eraser)")

        # Output label (top 3 predictions)
        label = gr.Label(num_top_classes=3, label="Predictions")

    # Trigger prediction whenever the image changes. The listener is
    # registered while the Blocks context is still open.
    im.change(predict, inputs=im, outputs=label)

demo.launch(share=True)