cisemh commited on
Commit
c25210a
·
verified ·
1 Parent(s): b11eeae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -29
app.py CHANGED
@@ -3,19 +3,15 @@ import gradio as gr
3
  import tensorflow as tf
4
  import numpy as np
5
 
6
- title = "Welcome on your first sketch recognition app!"
7
-
8
- head = (
9
- "<center>"
10
- "The robot was trained to classify numbers (from 0 to 9). To test it, write your number in the space provided."
11
- "</center>"
12
- )
13
-
14
- ref = "Find the whole code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
15
 
 
16
  img_size = 28
17
  labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
18
 
 
19
  model = tf.keras.models.load_model("number_recognition_model_colab.keras")
20
 
21
  def predict(img):
@@ -24,41 +20,42 @@ def predict(img):
24
  if not isinstance(img, np.ndarray):
25
  img = np.array(img)
26
 
27
- # Print shape and type of the input image
28
- print(f"Initial image type: {type(img)}, shape: {img.shape}")
29
-
30
- # Ensure the image is in grayscale and has a single channel
31
  if img.ndim == 3 and img.shape[-1] == 3:
32
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
33
  elif img.ndim == 2:
34
  img = np.expand_dims(img, axis=-1)
35
 
36
- # Print the shape of the grayscale image
37
- print(f"Grayscale image shape: {img.shape}")
38
-
39
- # Resize the image
40
  img = cv2.resize(img, (img_size, img_size))
41
 
42
  # Normalize the image
43
  img = img.astype('float32') / 255.0
44
  img = img.reshape(1, img_size, img_size, 1)
45
 
46
- # Print the shape after resizing and normalizing
47
- print(f"Processed image shape: {img.shape}")
48
-
49
  preds = model.predict(img)[0]
50
 
51
- # Tahmin sonuçlarını formatla
52
- sorted_preds = sorted(zip(labels, preds), key=lambda x: x[1], reverse=True)[:3]
53
- formatted_preds = "\n".join([f"{label}: {prob:.2f}" for label, prob in sorted_preds])
54
 
55
- return formatted_preds
56
  except Exception as e:
57
- # Print the exception to the console
58
- print(f"Error during prediction: {e}")
59
  return {"Error": str(e)}
60
 
61
- label = gr.Label(num_top_classes=3)
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- interface = gr.Interface(fn=predict, inputs="sketchpad", outputs=label, title=title, description=head, article=ref)
64
- interface.launch(debug=True)
 
3
  import tensorflow as tf
4
  import numpy as np
5
 
6
+ # Title and description for the interface
7
+ title = "Welcome to your first sketch recognition app!"
8
+ head = "<center>The robot was trained to classify numbers (0 to 9). To test it, write your number in the space provided.</center>"
 
 
 
 
 
 
9
 
10
+ # Image size and label mapping
11
  img_size = 28
12
  labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
13
 
14
+ # Load the trained model
15
  model = tf.keras.models.load_model("number_recognition_model_colab.keras")
16
 
17
  def predict(img):
 
20
  if not isinstance(img, np.ndarray):
21
  img = np.array(img)
22
 
23
+ # Convert the image to grayscale if it's not already
 
 
 
24
  if img.ndim == 3 and img.shape[-1] == 3:
25
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
26
  elif img.ndim == 2:
27
  img = np.expand_dims(img, axis=-1)
28
 
29
+ # Resize the image to the expected input size
 
 
 
30
  img = cv2.resize(img, (img_size, img_size))
31
 
32
  # Normalize the image
33
  img = img.astype('float32') / 255.0
34
  img = img.reshape(1, img_size, img_size, 1)
35
 
36
+ # Get predictions from the model
 
 
37
  preds = model.predict(img)[0]
38
 
39
+ # Return the predicted probabilities for each class
40
+ return {label: float(pred) for label, pred in zip(labels, preds)}
 
41
 
 
42
  except Exception as e:
 
 
43
  return {"Error": str(e)}
44
 
45
+ # Use a sketchpad as input for drawing
46
+ input_component = gr.Sketchpad()
47
+
48
+ # Output will show the top 3 predicted classes
49
+ output_component = gr.Label(num_top_classes=3)
50
+
51
+ # Create the Gradio interface
52
+ interface = gr.Interface(
53
+ fn=predict,
54
+ inputs=input_component,
55
+ outputs=output_component,
56
+ title=title,
57
+ description=head
58
+ )
59
 
60
+ # Launch the interface
61
+ interface.launch(debug=True)