Mnist-Digits / app.py
cisemh's picture
Update app.py
d1c466d verified
raw
history blame
3.17 kB
import numpy as np
import tensorflow as tf
from tensorflow import keras
import gradio as gr
# Load and preprocess the MNIST dataset
def load_data():
    """Load MNIST via ``tf.keras.datasets`` and shape it for the CNN.

    Both image splits are cast to float32, divided by 255 (mapping 0-255
    pixel values into 0-1) and reshaped to ``(-1, 28, 28, 1)`` so they carry
    the single channel axis the model's input layer expects. Labels are
    passed through untouched.

    Returns:
        ``((X_train, y_train), (X_test, y_test))``.
    """
    def _to_model_input(images):
        # Cast before dividing so the scaling is done in float32.
        scaled = images.astype("float32") / 255
        return scaled.reshape(-1, 28, 28, 1)

    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
    return (
        (_to_model_input(train_images), train_labels),
        (_to_model_input(test_images), test_labels),
    )
# Build the CNN model
def build_model(input_shape, num_classes):
    """Assemble a small convolutional classifier with the Keras functional API.

    Layer stack, applied in order to the input tensor:
    Conv2D(28 filters, 3x3, ReLU) -> MaxPooling2D(2x2) -> Flatten
    -> Dense(128, ReLU) -> Dense(num_classes, softmax).

    Args:
        input_shape: Shape of one input sample, passed to ``keras.layers.Input``.
        num_classes: Number of output units in the final softmax layer.

    Returns:
        An uncompiled ``keras.models.Model``.
    """
    inputs = keras.layers.Input(input_shape)
    layer_stack = [
        keras.layers.Conv2D(28, kernel_size=(3, 3), activation='relu'),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(num_classes, activation='softmax'),
    ]
    tensor = inputs
    for layer in layer_stack:
        tensor = layer(tensor)
    return keras.models.Model(inputs=inputs, outputs=tensor)
# Preprocess input for prediction
def preprocess_image(image):
    """Turn an input image into a normalized ``(1, 28, 28, 1)`` float32 batch.

    Two kinds of input are accepted:

    * An object with a ``convert`` method (a PIL image). It is converted to
      grayscale mode ``"L"`` and resized to 28x28 if needed.
    * An array-like of 0-255 pixel values. A 2-D array is used as-is. For a
      3-D array, a single/two-channel image keeps its first channel, and an
      RGB(A) image is reduced to gray with the same luma weights PIL uses
      for mode ``"L"``.

    Pixel values are divided by 255 and a batch and channel axis are added.

    Args:
        image: PIL image or array-like of pixel values.

    Returns:
        numpy float32 array of shape ``(1, 28, 28, 1)``.

    Raises:
        ValueError: if the grayscale pixel array is not 28x28.
    """
    if hasattr(image, "convert"):
        # PIL path: force 8-bit grayscale and the model's input size.
        image = image.convert("L")
        if image.size != (28, 28):
            image = image.resize((28, 28))

    pixels = np.asarray(image, dtype="float32")
    if pixels.ndim == 3:
        if pixels.shape[-1] >= 3:
            # ITU-R 601-2 luma weights (what PIL's convert("L") applies); alpha is dropped.
            luma = np.array([0.299, 0.587, 0.114], dtype="float32")
            pixels = pixels[..., :3] @ luma
        else:
            # Gray (or gray + alpha): keep the intensity channel.
            pixels = pixels[..., 0]

    if pixels.shape != (28, 28):
        raise ValueError(
            f"Expected a 28x28 image, got pixel array of shape {pixels.shape}"
        )
    return (pixels / 255).reshape(1, 28, 28, 1)
# Predict digit
def predict_digit(image):
    """Classify a drawn digit and report the probability of every class.

    Relies on the module-level ``model`` and ``classes_names`` that the
    ``__main__`` block creates before the Gradio interface is launched.

    Args:
        image: Image supplied by the Gradio input component; passed
            unchanged to ``preprocess_image``.

    Returns:
        ``(label, results)``. ``label`` is the name of the most probable
        class. ``results`` maps every class name to its predicted
        probability as a Python float, for the ``gr.outputs.Label`` output.
    """
    processed_image = preprocess_image(image)
    # The interface is live=True, so this runs on every input change;
    # verbose=0 keeps Keras from printing a progress bar for each call.
    probabilities = model.predict(processed_image, verbose=0)[0]
    class_id = int(np.argmax(probabilities))
    label = classes_names[class_id]
    results = {name: float(probabilities[i]) for i, name in enumerate(classes_names)}
    return label, results
if __name__ == "__main__":
    from pathlib import Path

    # Parameters
    classes_names = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
    input_shape = (28, 28, 1)
    num_classes = len(classes_names)

    # Load data
    (X_train, y_train), (X_test, y_test) = load_data()

    # Build and train model
    model = build_model(input_shape, num_classes)
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    print("Training model...")
    model.fit(X_train, y_train, epochs=3, batch_size=64)  # Quick training for demonstration

    # Report held-out performance so the loaded test split is actually used.
    test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=0)
    print(f"Test accuracy: {test_accuracy:.4f} (loss {test_loss:.4f})")

    # Gradio Interface
    title = "Welcome to Your First Sketch Recognition App!"
    description = (
        "The robot was trained to classify numbers (from 0 to 9). To test it, draw your number in the space provided."
    )

    # Register the example only if the file is present; a missing example
    # file may make Gradio fail while setting up the interface.
    example_path = Path("example_image.png")
    examples = [[str(example_path)]] if example_path.is_file() else None

    interface = gr.Interface(
        fn=predict_digit,
        inputs=gr.inputs.Image(
            shape=(28, 28),
            image_mode="L",
            invert_colors=True,
            # A drawing canvas, as the label and description promise
            # (the default source is an upload box).
            source="canvas",
            # Hand preprocess_image a PIL image, the input it was written for
            # (the default type is a numpy array).
            type="pil",
            label="Draw a Digit",
        ),
        outputs=[
            gr.outputs.Textbox(label="Predicted Digit"),
            gr.outputs.Label(num_top_classes=10, label="Prediction Confidence"),
        ],
        title=title,
        description=description,
        examples=examples,
        live=True,
    )

    # Launch Gradio interface
    print("Launching Gradio interface...")
    interface.launch()