# Page residue from the hosting Space's web UI ("Spaces:" status: Sleeping); not part of the program.
import gradio as gr
import numpy as np
import tensorflow as tf
from matplotlib.figure import Figure
from PIL import Image
# Load the pre-trained Keras model once at import time so every request reuses it.
# NOTE(review): the path is relative to the current working directory, so the app
# must be started from the directory that contains the .h5 file.
model = tf.keras.models.load_model("number_recognition_model_colab.h5")

# Human-readable label for each model output index, in index order
# (assumes output index i corresponds to digit i, as in MNIST — confirm against training).
classes_names = [
    "zero", "one", "two", "three", "four",
    "five", "six", "seven", "eight", "nine",
]
def preprocess_image(image):
    """Convert an image array into the single-sample batch the model expects.

    Steps:
      - collapse colour channels into one grayscale channel by averaging
        (an alpha channel, if present, is discarded first),
      - resize to 28x28 with Lanczos resampling,
      - scale pixel values from 0-255 to [0, 1].

    Args:
        image: Array-like of shape (H, W) or (H, W, C) with pixel values on
            the 0-255 scale.

    Returns:
        A float32 NumPy array of shape (1, 28, 28, 1).

    Raises:
        ValueError: If ``image`` is None (e.g. nothing was submitted).
    """
    if image is None:
        raise ValueError("No image provided.")
    image = np.asarray(image)

    if image.ndim == 3:
        # LA (2 channels) and RGBA (4 channels) carry alpha in the last
        # channel; it is opacity, not intensity, so it must not be averaged in.
        if image.shape[2] in (2, 4):
            image = image[..., :-1]
        image = image.mean(axis=2)

    # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
    image = np.clip(image, 0, 255).astype(np.uint8)

    resized = Image.fromarray(image).resize((28, 28), Image.LANCZOS)
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    # Batch of one, single channel: (1, 28, 28, 1).
    return pixels.reshape(1, 28, 28, 1)
def predict_digit(image):
    """Classify the digit in ``image`` with the module-level ``model``.

    Args:
        image: Image array accepted by ``preprocess_image``.

    Returns:
        A ``(result_text, probabilities)`` tuple.
          - ``result_text`` is a two-line summary with the predicted label and
            its confidence.
          - ``probabilities`` is the model's output row for this image
            multiplied by 100, with one entry per class in ``classes_names``
            order. It is presumably softmax output, i.e. percentages — confirm
            against the model definition.
    """
    batch = preprocess_image(image)
    # verbose=0 stops Keras from printing a progress bar on every request.
    scores = model.predict(batch, verbose=0)[0]

    best = int(np.argmax(scores))
    confidence = scores[best] * 100
    result_text = f"Predicted Digit: {classes_names[best]}\nConfidence: {confidence:.2f}%"

    # Per-class percentages, consumed by the probability bar chart.
    return result_text, scores * 100
def create_probability_plot(probabilities):
    """Build a bar chart of per-digit probabilities.

    Uses a standalone ``matplotlib.figure.Figure`` rather than pyplot.
    pyplot keeps every figure it creates alive in global state until it is
    explicitly closed, so calling ``plt.figure`` once per request leaks
    memory in a long-running server. A bare ``Figure`` is freed once no
    longer referenced.

    Args:
        probabilities: Sequence of percentages, one per entry of ``classes_names``.

    Returns:
        The ``matplotlib.figure.Figure`` holding the chart. It is accepted
        directly by ``gr.Plot`` and exposes ``savefig``/``show``.
    """
    fig = Figure(figsize=(10, 5))
    ax = fig.subplots()
    ax.bar(classes_names, probabilities)
    ax.set_title("Digit Probability Distribution")
    ax.set_xlabel("Digits")
    ax.set_ylabel("Probability (%)")
    ax.set_ylim(0, 100)
    # Rotate the x-axis labels so the digit names do not overlap.
    ax.tick_params(axis="x", labelrotation=45)
    return fig
# Gradio callback connecting the model to the UI components.
def gradio_predict(image):
    """Return ``(prediction_text, probability_figure)`` for the Gradio outputs.

    If ``image`` is None (nothing was submitted), return a prompt and no plot
    instead of failing inside preprocessing.
    """
    if image is None:
        return "Please draw or upload a digit first.", None
    result_text, probabilities = predict_digit(image)
    return result_text, create_probability_plot(probabilities)
# Set up the Gradio interface: one image input; a text and a plot output.
iface = gr.Interface(
    fn=gradio_predict,
    # image_mode="L" requests a single-channel (grayscale) image as a NumPy array.
    inputs=gr.Image(type="numpy", image_mode="L"),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Plot(label="Probability Distribution"),
    ],
    title="MNIST Digit Recognizer",
    description="Draw a single-digit number (0-9) and the model will predict which digit it is!",
    # NOTE(review): newer Gradio releases deprecate allow_flagging in favour of flagging_mode.
    allow_flagging="never",
    # NOTE(review): presumably resolved relative to the working directory;
    # confirm these files ship with the app.
    examples=[
        ["example_zero.png"],
        ["example_one.png"],
        ["example_two.png"],
    ],
)

# Launch the web UI only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()