Spaces:
Sleeping
Sleeping
from pathlib import Path

import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow import keras
# Load and preprocess the MNIST dataset
def load_data():
    """Fetch MNIST and return it in the layout the CNN expects.

    Each split's images are cast to float32, divided by 255, and reshaped
    to (-1, 28, 28, 1) so they carry an explicit single-channel axis.
    Labels are returned unchanged.

    Returns:
        ((train_images, train_labels), (test_images, test_labels))
    """
    def _to_model_input(images):
        # Scale first, then add the channel axis (same result as the other order).
        scaled = images.astype("float32") / 255
        return scaled.reshape(-1, 28, 28, 1)

    train_split, test_split = tf.keras.datasets.mnist.load_data()
    train_images, train_labels = train_split
    test_images, test_labels = test_split
    return (
        (_to_model_input(train_images), train_labels),
        (_to_model_input(test_images), test_labels),
    )
# Build the CNN model
def build_model(input_shape, num_classes):
    """Assemble the digit-classifier network with the Keras functional API.

    Layer stack:
    - a 3x3 ReLU convolution with 28 filters;
    - 2x2 max pooling;
    - flatten;
    - a 128-unit ReLU dense layer;
    - a softmax output layer with ``num_classes`` units.

    Args:
        input_shape: shape of one input sample, e.g. (28, 28, 1).
        num_classes: number of output units / classes.

    Returns:
        An uncompiled ``keras.models.Model``.
    """
    layers = keras.layers
    inputs = layers.Input(input_shape)
    hidden = inputs
    for layer in (
        layers.Conv2D(28, kernel_size=(3, 3), activation='relu'),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
    ):
        hidden = layer(hidden)
    outputs = layers.Dense(num_classes, activation='softmax')(hidden)
    return keras.models.Model(inputs=inputs, outputs=outputs)
# Preprocess input for prediction
def preprocess_image(image, size=(28, 28)):
    """Convert an input image into a normalized single-image batch for the model.

    Accepts either a PIL-style image (any object with a ``convert`` method)
    or a NumPy array. Gradio's legacy ``Image`` input passes a NumPy array
    unless ``type="pil"`` is requested, and arrays have no ``convert``
    method, so both forms are handled here.

    Args:
        image: PIL-style image, or array of shape (H, W) or (H, W, C).
        size: target (height, width); defaults to MNIST's 28x28.

    Returns:
        float32 array of shape (1, height, width, 1) whose pixel values
        have been divided by 255.

    Raises:
        ValueError: if ``image`` is None or is not 2-D/3-D.
    """
    if image is None:
        raise ValueError("No image provided.")
    if hasattr(image, "convert"):
        image = image.convert("L")  # grayscale, as before
    pixels = np.asarray(image, dtype="float32")
    if pixels.ndim == 3:
        if pixels.shape[-1] >= 3:
            # Luma weights PIL uses for "L" mode; any alpha channel is ignored.
            pixels = pixels[..., :3] @ np.array([0.299, 0.587, 0.114], dtype="float32")
        else:
            pixels = pixels[..., 0]
    if pixels.ndim != 2:
        raise ValueError(f"Expected a 2-D or 3-D image, got shape {pixels.shape}.")
    height, width = size
    if pixels.shape != (height, width):
        # Nearest-neighbour resize so inputs of other sizes no longer fail in reshape.
        rows = np.arange(height) * pixels.shape[0] // height
        cols = np.arange(width) * pixels.shape[1] // width
        pixels = pixels[np.ix_(rows, cols)]
    pixels = pixels / 255  # maps uint8 pixel values into [0, 1]
    return pixels.reshape(1, height, width, 1)
# Predict digit
def predict_digit(image):
    """Classify a digit image and report per-class confidences.

    Relies on the module-level ``model`` and ``classes_names`` that are
    created in this script's ``__main__`` section.

    Args:
        image: the image delivered by the Gradio ``Image`` input.

    Returns:
        (label, results):
        - ``label`` is the name of the highest-scoring class.
        - ``results`` maps every class name to its softmax score as a
          Python float (the dict shape ``gr.outputs.Label`` displays).
    """
    processed_image = preprocess_image(image)
    # verbose=0: the interface runs with live=True, so predict() is invoked
    # on every change and the default progress bar would flood the logs.
    probabilities = model.predict(processed_image, verbose=0)[0]  # single-image batch
    class_id = int(np.argmax(probabilities))
    label = classes_names[class_id]
    results = {name: float(prob) for name, prob in zip(classes_names, probabilities)}
    return label, results
if __name__ == "__main__":
    # Parameters
    classes_names = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
    input_shape = (28, 28, 1)
    num_classes = len(classes_names)

    # Load data
    (X_train, y_train), (X_test, y_test) = load_data()

    # Build and train model
    model = build_model(input_shape, num_classes)
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    print("Training model...")
    # The test split was loaded but never used; report accuracy on it after
    # each epoch so this quick demonstration run can be sanity-checked.
    model.fit(
        X_train,
        y_train,
        epochs=3,
        batch_size=64,
        validation_data=(X_test, y_test),
    )  # Quick training for demonstration

    # Gradio Interface
    title = "Welcome to Your First Sketch Recognition App!"
    description = (
        "The robot was trained to classify numbers (from 0 to 9). To test it, draw your number in the space provided."
    )
    # Add example image paths here. Only files that actually exist are offered;
    # listing a missing file is likely to make Gradio fail when it loads examples.
    examples = [[path] for path in ["example_image.png"] if Path(path).exists()] or None
    interface = gr.Interface(
        fn=predict_digit,
        # type="pil": preprocess_image calls image.convert(), which NumPy arrays
        # (the legacy Image input's default type) do not provide.
        # source="canvas": the description asks users to draw the digit.
        inputs=gr.inputs.Image(
            shape=(28, 28),
            image_mode="L",
            invert_colors=True,
            source="canvas",
            type="pil",
            label="Draw a Digit",
        ),
        outputs=[
            gr.outputs.Textbox(label="Predicted Digit"),
            gr.outputs.Label(num_top_classes=num_classes, label="Prediction Confidence"),
        ],
        title=title,
        description=description,
        examples=examples,
        live=True,
    )

    # Launch Gradio interface
    print("Launching Gradio interface...")
    interface.launch()