cisemh commited on
Commit
c80c536
·
verified ·
1 Parent(s): 609593a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -18
app.py CHANGED
@@ -1,21 +1,46 @@
 
 
1
  import tensorflow as tf
2
  import numpy as np
3
- from urllib.request import urlretrieve
4
- import gradio as gr
5
 
6
- urlretrieve("https://gr-models.s3-us-west-2.amazonaws.com/mnist-model.h5", "mnist-model.h5")
7
- model = tf.keras.models.load_model("mnist-model.h5")
8
-
9
- def recognize_digit(image):
10
- image = image.reshape(1, -1) # add a batch dimension
11
- prediction = model.predict(image).tolist()[0]
12
- return {str(i): prediction[i] for i in range(10)}
13
-
14
- gr.Interface(fn=recognize_digit,
15
- inputs="sketchpad",
16
- outputs=gr.outputs.Label(num_top_classes=3),
17
- live=True,
18
- css=".footer {display:none !important}",
19
- # title="MNIST Sketchpad",
20
- description="Draw a number 0 through 9 on the sketchpad, and see predictions in real time.",
21
- thumbnail="https://raw.githubusercontent.com/gradio-app/real-time-mnist/master/thumbnail2.png").launch();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import gradio as gr
3
  import tensorflow as tf
4
  import numpy as np
 
 
5
 
6
+ # Model yükleniyor
7
+ model = tf.keras.models.load_model("number_recognition_model_colab.keras")
8
+
9
+ # Etiketler (0'dan 9'a kadar sayılar)
10
+ labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
11
+
12
+ # Tahmin fonksiyonu
13
+ def predict(img):
14
+ try:
15
+ # Görüntüyü gri tonlamaya dönüştür
16
+ if img.ndim == 3 and img.shape[-1] == 3:
17
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
18
+ elif img.ndim == 2:
19
+ img = np.expand_dims(img, axis=-1)
20
+
21
+ # Görüntüyü yeniden boyutlandır ve normalize et
22
+ img = cv2.resize(img, (28, 28))
23
+ img = img.astype('float32') / 255.0
24
+ img = img.reshape(1, 28, 28, 1)
25
+
26
+ # Modelden tahmin al
27
+ preds = model.predict(img)[0]
28
+
29
+ # Tahmin sonuçlarını formatla
30
+ sorted_preds = sorted(zip(labels, preds), key=lambda x: x[1], reverse=True)[:3]
31
+ formatted_preds = "\n".join([f"{label}: {prob:.2f}" for label, prob in sorted_preds])
32
+
33
+ return formatted_preds
34
+ except Exception as e:
35
+ return f"Error: {e}"
36
+
37
+ # Gradio arayüzü
38
+ interface = gr.Interface(
39
+ fn=predict,
40
+ inputs="sketchpad",
41
+ outputs="textbox",
42
+ title="Sketch Recognition App",
43
+ description="Draw a number (0-9) and see the model's top predictions."
44
+ )
45
+
46
+ interface.launch(debug=True)