"""Gradio demo that classifies a single handwritten digit with a pre-trained Keras model."""

import gradio as gr
import numpy as np
import tensorflow as tf
from matplotlib.figure import Figure
from PIL import Image

# Load the pre-trained model once at import time so every request reuses it.
model = tf.keras.models.load_model('number_recognition_model_colab.h5')

# Class names for MNIST digits: output index i of the model maps to classes_names[i].
classes_names = ["zero", "one", "two", "three", "four",
                 "five", "six", "seven", "eight", "nine"]

# Side length (pixels) of the square image fed to the model.
IMAGE_SIZE = 28


def preprocess_image(image):
    """
    Preprocess the input image to match the model's expected input.

    - Convert to grayscale (mean over color channels; any alpha channel is dropped)
    - Resize to 28x28 with Lanczos resampling
    - Normalize pixel values from [0, 255] to [0, 1]

    Args:
        image: array of shape (H, W) or (H, W, C). Presumably uint8 pixel values
            in [0, 255], as delivered by ``gr.Image(type="numpy")`` — TODO confirm.

    Returns:
        float32 array of shape (1, 28, 28, 1).

    NOTE(review): MNIST digits are light strokes on a dark background. If the
    model was trained on raw MNIST, dark-on-light inputs may need inverting
    (``1.0 - pixels``). Confirm against how the model was trained.
    """
    image = np.asarray(image)

    if image.ndim == 3:
        # LA / RGBA: drop the alpha channel. Averaging it in would pull dark,
        # fully opaque strokes toward mid-gray and reduce contrast.
        if image.shape[2] in (2, 4):
            image = image[..., :-1]
        image = np.mean(image, axis=2)

    # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
    image = np.clip(image, 0, 255).astype(np.uint8)

    resized = Image.fromarray(image).resize((IMAGE_SIZE, IMAGE_SIZE), Image.LANCZOS)
    pixels = np.asarray(resized, dtype=np.float32) / 255.0

    # Add batch and channel dimensions: (1, height, width, 1).
    return pixels.reshape(1, IMAGE_SIZE, IMAGE_SIZE, 1)


def predict_digit(image):
    """
    Predict the digit in the input image.

    Args:
        image: raw image array accepted by ``preprocess_image``.

    Returns:
        A tuple ``(result_text, probabilities)``:
        - ``result_text``: the predicted label and its confidence.
        - ``probabilities``: a length-10 array holding the model's per-class
          outputs scaled by 100. This assumes the model ends in a softmax, so
          the values are percentages — TODO confirm.
    """
    processed_image = preprocess_image(image)

    # verbose=0 stops Keras from printing a progress bar on every request.
    predictions = model.predict(processed_image, verbose=0)
    scores = predictions[0]

    predicted_class = int(np.argmax(scores))
    predicted_label = classes_names[predicted_class]
    confidence = scores[predicted_class] * 100

    result_text = f"Predicted Digit: {predicted_label}\nConfidence: {confidence:.2f}%"
    probabilities = scores * 100
    return result_text, probabilities


def create_probability_plot(probabilities):
    """
    Create a bar plot of digit probabilities.

    This builds a standalone ``matplotlib.figure.Figure`` instead of using the
    pyplot state machine, for two reasons:
    - pyplot keeps every figure it creates alive until it is explicitly
      closed, so the old code leaked one figure per request.
    - pyplot's global "current figure" is shared across Gradio's worker threads.

    Args:
        probabilities: 10 values (percent), one per entry of ``classes_names``.

    Returns:
        The Figure, which ``gr.Plot`` renders directly.
    """
    fig = Figure(figsize=(10, 5))
    ax = fig.subplots()
    ax.bar(classes_names, probabilities)
    ax.set_title('Digit Probability Distribution')
    ax.set_xlabel('Digits')
    ax.set_ylabel('Probability (%)')
    ax.set_ylim(0, 100)
    # Rotate x-axis labels so the spelled-out digit names do not overlap.
    ax.tick_params(axis='x', labelrotation=45)
    return fig


def gradio_predict(image):
    """
    Wrapper function for the Gradio interface.

    Returns:
        ``(prediction text, probability figure)``. If the user submits without
        an image, Gradio passes ``None``. In that case a prompt and ``None`` are
        returned, so the plot stays empty instead of the request erroring out
        inside preprocessing.
    """
    if image is None:
        return "Please provide an image of a digit.", None

    result_text, probabilities = predict_digit(image)
    prob_plot = create_probability_plot(probabilities)
    return result_text, prob_plot


# Set up the Gradio interface.
# image_mode="L" asks Gradio to deliver the image as single-channel grayscale.
# NOTE(review): the example PNGs must exist next to this script, or the
# interface fails at startup. `allow_flagging` is deprecated in newer Gradio
# releases in favor of `flagging_mode` — confirm against the pinned version.
iface = gr.Interface(
    fn=gradio_predict,
    inputs=gr.Image(type="numpy", image_mode="L"),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Plot(label="Probability Distribution"),
    ],
    title="MNIST Digit Recognizer",
    description="Draw a single-digit number (0-9) and the model will predict which digit it is!",
    allow_flagging="never",
    examples=[
        ["example_zero.png"],
        ["example_one.png"],
        ["example_two.png"],
    ],
)

# Launch the interface
if __name__ == "__main__":
    iface.launch()