Mnist-Digits / app.py
cisemh's picture
Update app.py
babe4e3 verified
raw
history blame
2.87 kB
import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
# Keras model, loaded once at import time and reused by every prediction call.
model = tf.keras.models.load_model('number_recognition_model_colab.h5')

# Human-readable label for each model output index (index 0 -> "zero", ..., 9 -> "nine").
classes_names = "zero one two three four five six seven eight nine".split()
def preprocess_image(image):
    """
    Convert an input image array into a single-image batch for the model.

    Steps:
    - Collapse colour channels to one grayscale channel (alpha is ignored)
    - Resize to 28x28 with Lanczos resampling
    - Scale pixel values from [0, 255] to [0.0, 1.0]
    - Reshape to (1, 28, 28, 1)

    Args:
        image: array-like of shape (H, W) or (H, W, C). Pixel values are
            expected in the 0-255 range (they are cast to uint8 before resizing).

    Returns:
        float32 NumPy array of shape (1, 28, 28, 1).

    Raises:
        ValueError: if ``image`` is None.
    """
    if image is None:
        # Fail with a clear message instead of an AttributeError on `.shape`.
        raise ValueError("No image provided")
    image = np.asarray(image)

    # Convert to grayscale if the image has a channel axis
    if image.ndim == 3:
        # Drop a trailing alpha channel (LA or RGBA) so transparency does not
        # get averaged into the intensity.
        if image.shape[2] in (2, 4):
            image = image[..., :-1]
        image = np.mean(image, axis=2)

    # Clip before the uint8 cast so out-of-range values saturate instead of
    # wrapping around (e.g. 256 -> 0).
    pixels = np.clip(image, 0, 255).astype(np.uint8)

    # Resize to 28x28
    resized = Image.fromarray(pixels).resize((28, 28), Image.LANCZOS)

    # Normalize to [0, 1]
    scaled = np.array(resized).astype("float32") / 255.0

    # Reshape to (batch, height, width, channels)
    return scaled.reshape(1, 28, 28, 1)
def predict_digit(image):
    """
    Predict the digit shown in ``image``.

    Args:
        image: raw input array; passed unchanged to ``preprocess_image``.

    Returns:
        ``(result_text, probabilities)``:
        - ``result_text``: two-line summary with the predicted label and its
          confidence as a percentage.
        - ``probabilities``: the model's output row for this image multiplied
          by 100, one value per entry of ``classes_names``.
    """
    processed_image = preprocess_image(image)

    # verbose=0 stops Keras from printing a progress bar to the server log on
    # every request.
    predictions = model.predict(processed_image, verbose=0)

    # The batch holds a single image, so only the first row is relevant.
    # NOTE(review): reading these outputs as probabilities assumes the model
    # ends in a softmax layer - confirm against the training code.
    scores = predictions[0]

    predicted_class = int(np.argmax(scores))
    predicted_label = classes_names[predicted_class]
    confidence = scores[predicted_class] * 100

    result_text = f"Predicted Digit: {predicted_label}\nConfidence: {confidence:.2f}%"

    # Scale to percent for display in the probability plot.
    probabilities = scores * 100
    return result_text, probabilities
def create_probability_plot(probabilities):
    """
    Build a bar chart of the per-digit probabilities.

    Args:
        probabilities: sequence of percentages, one per entry of
            ``classes_names`` (as returned by ``predict_digit``).

    Returns:
        A ``matplotlib.figure.Figure`` that ``gr.Plot`` can render.
    """
    # Use a standalone Figure rather than pyplot. pyplot registers every
    # figure it creates in global state and never frees it unless closed,
    # so the old code leaked one figure per request. Returning the `plt`
    # module also made Gradio render whichever figure was globally "current",
    # and concurrent requests could change that figure.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 5))
    ax = fig.subplots()
    ax.bar(classes_names, probabilities)
    ax.set_title('Digit Probability Distribution')
    ax.set_xlabel('Digits')
    ax.set_ylabel('Probability (%)')
    ax.set_ylim(0, 100)
    # Rotate x-axis labels so the spelled-out digit names don't overlap
    ax.tick_params(axis='x', labelrotation=45)
    return fig
# Gradio callback connecting the prediction helpers to the UI
def gradio_predict(image):
    """
    Wrapper function for the Gradio interface.

    Args:
        image: value from the ``gr.Image`` input; may be None if the user
            submits without providing an image.

    Returns:
        ``(result_text, plot)``, matching the interface's Textbox and Plot
        outputs. With no image, returns a prompt message and None (empty
        plot) instead of failing inside preprocessing.
    """
    if image is None:
        return "Please provide an image of a digit first.", None
    result_text, probabilities = predict_digit(image)
    prob_plot = create_probability_plot(probabilities)
    return result_text, prob_plot
# Set up the Gradio interface: one image input, a text result and a probability plot.
iface = gr.Interface(
    fn=gradio_predict,
    # image_mode="L" asks Gradio to convert the input to single-channel grayscale.
    inputs=gr.Image(type="numpy", image_mode="L"),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Plot(label="Probability Distribution"),
    ],
    title="MNIST Digit Recognizer",
    # A default gr.Image accepts an uploaded/pasted image, not a drawing,
    # so the description asks for an image rather than a drawing.
    description="Upload an image of a single digit (0-9) and the model will predict which digit it is!",
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # confirm against the pinned Gradio version.
    allow_flagging="never",
    # NOTE(review): these example files must exist alongside this script;
    # they are not part of this source.
    examples=[
        ["example_zero.png"],
        ["example_one.png"],
        ["example_two.png"],
    ],
)
# Launch the interface: start the Gradio server only when this file is run
# directly as a script, not when it is imported as a module.
if __name__ == "__main__":
    iface.launch()