Mnist-Digits / app.py
cisemh's picture
Update app.py
c80c536 verified
raw
history blame
1.4 kB
import cv2
import gradio as gr
import tensorflow as tf
import numpy as np
# Load the trained Keras classifier once, at import time; predict() reuses
# this single module-level instance for every request.
_MODEL_FILE = "number_recognition_model_colab.keras"
model = tf.keras.models.load_model(_MODEL_FILE)

# Human-readable class names; predict() pairs labels[i] with the model's
# i-th output score, so labels[i] names digit i.
labels = [
    "zero", "one", "two", "three", "four",
    "five", "six", "seven", "eight", "nine",
]
def _to_grayscale(img):
    """Collapse a 2-D, 1-channel, RGB or RGBA image array to a 2-D grayscale array.

    Raises:
        ValueError: if the array is not one of the supported layouts.
    """
    if img.ndim == 2:
        return img
    if img.ndim == 3:
        channels = img.shape[-1]
        if channels == 4:
            # NOTE(review): RGBA2GRAY ignores alpha; if the sketchpad canvas is
            # transparent rather than filled, strokes may need to be taken from
            # the alpha channel instead -- confirm for the pinned Gradio version.
            return cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
        if channels == 3:
            # Gradio hands over RGB arrays (PIL order), not OpenCV's BGR.
            return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        if channels == 1:
            return img[..., 0]
    raise ValueError(f"unsupported image shape {img.shape}")


# Prediction function
def predict(img):
    """Classify a hand-drawn digit and return the top-3 guesses as text.

    Args:
        img: Value produced by the Gradio "sketchpad" input: a NumPy image
            array (H x W, H x W x 1, H x W x 3 RGB or H x W x 4 RGBA), or
            ``None`` when nothing was submitted. NOTE(review): newer Gradio
            releases pass a dict whose rendered image is under "composite";
            that form is unwrapped here -- confirm for the pinned version.

    Returns:
        Up to three lines of ``"<label>: <score>"``, sorted by descending
        score with scores formatted to two decimals, or an ``"Error: ..."``
        string if no drawing arrived or preprocessing/inference failed
        (errors are returned, not raised, so they show up in the UI textbox).
    """
    if isinstance(img, dict):
        img = img.get("composite")
    if img is None:
        return "Error: no drawing received"

    try:
        # Convert to a single grayscale channel, shrink to the model's
        # 28x28 input and scale 8-bit pixel values into [0, 1].
        gray = _to_grayscale(np.asarray(img))
        # INTER_AREA averages source pixels when shrinking, so thin strokes
        # from a large canvas are not skipped the way bilinear sampling can.
        small = cv2.resize(gray, (28, 28), interpolation=cv2.INTER_AREA)
        batch = (small.astype("float32") / 255.0).reshape(1, 28, 28, 1)

        # verbose=0 stops Keras printing a progress bar on every request.
        preds = model.predict(batch, verbose=0)[0]

        # Pair each label with its score and keep the three highest.
        top3 = sorted(zip(labels, preds), key=lambda pair: pair[1], reverse=True)[:3]
        return "\n".join(f"{label}: {prob:.2f}" for label, prob in top3)
    except Exception as e:
        # UI boundary: surface any failure to the user instead of crashing.
        return f"Error: {e}"
# Gradio UI: a sketchpad input wired to predict(), with the formatted top
# guesses shown in a plain textbox.
_UI_TEXT = {
    "title": "Sketch Recognition App",
    "description": "Draw a number (0-9) and see the model's top predictions.",
}

interface = gr.Interface(
    fn=predict,
    inputs="sketchpad",
    outputs="textbox",
    **_UI_TEXT,
)

# Start the web server; per the Gradio launch() docs, debug=True blocks the
# main thread and surfaces errors in the console / notebook output.
interface.launch(debug=True)