cisemh commited on
Commit
b304a71
·
verified ·
1 Parent(s): a2c1754

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -63
app.py CHANGED
@@ -1,63 +1,56 @@
1
- import cv2
2
- import gradio as gr
3
- import tensorflow as tf
4
- import numpy as np
5
-
6
- title = "Welcome on your first sketch recognition app!"
7
-
8
- head = (
9
- "<center>"
10
- "The robot was trained to classify numbers (from 0 to 9). To test it, write your number in the space provided."
11
- "</center>"
12
- )
13
-
14
- ref = "Find the whole code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
15
-
16
- img_size = 28
17
- labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
18
-
19
- model = tf.keras.models.load_model("number_recognition_model_colab.keras")
20
-
21
- def predict(img):
22
- try:
23
- # Convert the input image to a NumPy array if needed
24
- if not isinstance(img, np.ndarray):
25
- img = np.array(img)
26
-
27
- # Print shape and type of the input image
28
- print(f"Initial image type: {type(img)}, shape: {img.shape}")
29
-
30
- # Ensure the image is in grayscale and has a single channel
31
- if img.ndim == 3 and img.shape[-1] == 3:
32
- img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
33
- elif img.ndim == 2:
34
- img = np.expand_dims(img, axis=-1)
35
-
36
- # Print the shape of the grayscale image
37
- print(f"Grayscale image shape: {img.shape}")
38
-
39
- # Resize the image
40
- img = cv2.resize(img, (img_size, img_size))
41
-
42
- # Normalize the image
43
- img = img.astype('float32') / 255.0
44
- img = img.reshape(1, img_size, img_size, 1)
45
-
46
- # Print the shape after resizing and normalizing
47
- print(f"Processed image shape: {img.shape}")
48
-
49
- preds = model.predict(img)[0]
50
-
51
- # Print the predictions
52
- print("Predictions:", preds)
53
-
54
- return {label: float(pred) for label, pred in zip(labels, preds)}
55
- except Exception as e:
56
- # Print the exception to the console
57
- print(f"Error during prediction: {e}")
58
- return {"Error": str(e)}
59
-
60
- label = gr.Label(num_top_classes=3)
61
-
62
- interface = gr.Interface(fn=predict, inputs="sketchpad", outputs=label, title=title, description=head, article=ref)
63
- interface.launch(debug=True)
 
1
+ import cv2
2
+ import gradio as gr
3
+ import tensorflow as tf
4
+ import numpy as np
5
+
6
+ # Başlık ve açıklama
7
+ title = "Welcome on your first sketch recognition app!"
8
+ description = (
9
+ "The robot was trained to classify numbers (from 0 to 9). "
10
+ "To test it, write your number in the space provided."
11
+ )
12
+
13
+ article = "Find the whole code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
14
+
15
+ # Model parametreleri
16
+ img_size = 28
17
+ labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
18
+
19
+ # Model yükleme
20
+ model = tf.keras.models.load_model("number_recognition_model_colab.keras")
21
+
22
+ # Tahmin fonksiyonu
23
+ def predict(img):
24
+ try:
25
+ # Görüntüyü işleme
26
+ if not isinstance(img, np.ndarray):
27
+ img = np.array(img)
28
+
29
+ if img.ndim == 3 and img.shape[-1] == 3:
30
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
31
+
32
+ img = cv2.resize(img, (img_size, img_size))
33
+ img = img.astype('float32') / 255.0
34
+ img = img.reshape(1, img_size, img_size, 1)
35
+
36
+ preds = model.predict(img)[0]
37
+
38
+ return {label: float(pred) for label, pred in zip(labels, preds)}
39
+ except Exception as e:
40
+ return {"Error": str(e)}
41
+
42
+ # Gradio arayüzü
43
+ interface = gr.Interface(
44
+ fn=predict,
45
+ inputs=gr.inputs.Sketchpad(label="Draw a number"), # Kullanıcının sayı çizebileceği alan
46
+ outputs=gr.outputs.Label(num_top_classes=3, label="Predicted Number"), # Tahminlerin gösterileceği alan
47
+ title=title,
48
+ description=description,
49
+ article=article,
50
+ allow_flagging="manual", # Flagging'i etkinleştir
51
+ flagging_options=["Incorrect Prediction", "Other Issues"], # Kullanıcıların seçebileceği flag nedenleri
52
+ flagging_dir="flagged_results" # Flag sonuçlarını kaydedeceği klasör
53
+ )
54
+
55
+ # Uygulamayı başlat
56
+ interface.launch()