cisemh commited on
Commit
64c61da
·
verified ·
1 Parent(s): 9a87dce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -39
app.py CHANGED
@@ -3,48 +3,30 @@ import gradio as gr
3
  import tensorflow as tf
4
  import numpy as np
5
 
6
- title = "Welcome to your first sketch recognition app!"
7
- head = (
8
- "<center>"
9
- "The robot was trained to classify numbers (from 0 to 9). To test it, write your number in the space provided."
10
- "</center>"
11
- )
12
- ref = "Find the whole code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
13
 
14
- img_size = 28
15
- labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
16
 
17
  # Model yükleniyor
18
  model = tf.keras.models.load_model("number_recognition_model_colab.keras")
19
 
20
  def predict(img):
21
- try:
22
- # Girdi görselini NumPy array'e çevir
23
- if not isinstance(img, np.ndarray):
24
- img = np.array(img)
25
-
26
- # Görüntüyü gri tonlamaya çevir ve yeniden boyutlandır
27
- img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if img.ndim == 3 else img
28
- img = cv2.resize(img, (img_size, img_size))
29
- img = img.astype('float32') / 255.0
30
- img = img.reshape(1, img_size, img_size, 1)
31
-
32
- preds = model.predict(img)[0]
33
-
34
- return {label: float(pred) for label, pred in zip(labels, preds)}
35
- except Exception as e:
36
- return {"Error": str(e)}
37
-
38
- label = gr.Label(num_top_classes=3)
39
-
40
- # Yeni Gradio bileşenleriyle uyumlu hale getirildi
41
- interface = gr.Interface(
42
- fn=predict,
43
- inputs=gr.Sketchpad(label="Draw a number"), # Sketchpad kullanımı
44
- outputs=label,
45
- title=title,
46
- description=head,
47
- article=ref
48
- )
49
-
50
- interface.launch(debug=True)
 
3
  import tensorflow as tf
4
  import numpy as np
5
 
 
 
 
 
 
 
 
6
 
 
 
7
 
8
  # Model yükleniyor
9
  model = tf.keras.models.load_model("number_recognition_model_colab.keras")
10
 
11
  def predict(img):
12
+ # Preprocess the input image
13
+ img = img.reshape(1, 28, 28) / 255.0
14
+
15
+ # Make the prediction
16
+ prediction = model.predict(img)
17
+ predicted_digit = np.argmax(prediction[0])
18
+
19
+ return predicted_digit
20
+
21
+ # Create the Gradio interface
22
+ with gr.Blocks() as demo:
23
+ gr.Markdown("Welcome on your first sketch recognition app!")
24
+
25
+ with gr.Row():
26
+ sketchpad = gr.Sketchpad(shape=(28, 28))
27
+ output = gr.Text()
28
+
29
+ btn = gr.Button("Submit")
30
+ btn.click(predict, inputs=sketchpad, outputs=output)
31
+
32
+ demo.launch()