cisemh commited on
Commit
6e54a07
·
verified ·
1 Parent(s): 95004a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -21
app.py CHANGED
@@ -3,38 +3,61 @@ import gradio as gr
3
  import tensorflow as tf
4
  import numpy as np
5
 
6
- # Load the trained model
7
- model = tf.keras.models.load_model("number_recognition_model_colab.keras")
 
 
 
 
 
 
 
8
 
9
- # Image size and labels
10
  img_size = 28
11
  labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
12
 
13
- # Prediction function
 
14
  def predict(img):
15
  try:
16
- # Ensure the image is a NumPy array and grayscale
17
- img = cv2.cvtColor(np.array(img), cv2.COLOR_BGR2GRAY)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  img = cv2.resize(img, (img_size, img_size))
19
- img = img.astype("float32") / 255.0
 
 
20
  img = img.reshape(1, img_size, img_size, 1)
21
 
22
- # Make predictions
 
 
23
  preds = model.predict(img)[0]
 
 
 
 
24
  return {label: float(pred) for label, pred in zip(labels, preds)}
25
  except Exception as e:
 
 
26
  return {"Error": str(e)}
27
 
28
- # Gradio interface
29
- if __name__ == "__main__":
30
- demo = gr.Interface(
31
- fn=predict, # Function to call
32
- inputs=gr.Sketchpad(label="Draw a number"), # Sketchpad input
33
- outputs=gr.Label(num_top_classes=3), # Label output to show top 3 predictions
34
- title="Number Recognition App",
35
- description=(
36
- "The model was trained to classify numbers (from 0 to 9). "
37
- "Draw a number in the sketchpad below and see the prediction!"
38
- )
39
- )
40
- demo.launch()
 
3
  import tensorflow as tf
4
  import numpy as np
5
 
6
+ title = "Welcome on your first sketch recognition app!"
7
+
8
+ head = (
9
+ "<center>"
10
+ "The robot was trained to classify numbers (from 0 to 9). To test it, write your number in the space provided."
11
+ "</center>"
12
+ )
13
+
14
+ ref = "Find the whole code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."
15
 
 
16
  img_size = 28
17
  labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
18
 
19
+ model = tf.keras.models.load_model("number_recognition_model_colab.keras")
20
+
21
  def predict(img):
22
  try:
23
+ # Convert the input image to a NumPy array if needed
24
+ if not isinstance(img, np.ndarray):
25
+ img = np.array(img)
26
+
27
+ # Print shape and type of the input image
28
+ print(f"Initial image type: {type(img)}, shape: {img.shape}")
29
+
30
+ # Ensure the image is in grayscale and has a single channel
31
+ if img.ndim == 3 and img.shape[-1] == 3:
32
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
33
+ elif img.ndim == 2:
34
+ img = np.expand_dims(img, axis=-1)
35
+
36
+ # Print the shape of the grayscale image
37
+ print(f"Grayscale image shape: {img.shape}")
38
+
39
+ # Resize the image
40
  img = cv2.resize(img, (img_size, img_size))
41
+
42
+ # Normalize the image
43
+ img = img.astype('float32') / 255.0
44
  img = img.reshape(1, img_size, img_size, 1)
45
 
46
+ # Print the shape after resizing and normalizing
47
+ print(f"Processed image shape: {img.shape}")
48
+
49
  preds = model.predict(img)[0]
50
+
51
+ # Print the predictions
52
+ print("Predictions:", preds)
53
+
54
  return {label: float(pred) for label, pred in zip(labels, preds)}
55
  except Exception as e:
56
+ # Print the exception to the console
57
+ print(f"Error during prediction: {e}")
58
  return {"Error": str(e)}
59
 
60
+ label = gr.Label(num_top_classes=3)
61
+
62
+ interface = gr.Interface(fn=predict, inputs="sketchpad", outputs=label, title=title, description=head, article=ref)
63
+ interface.launch(debug=True)