cisemh commited on
Commit
babe4e3
·
verified ·
1 Parent(s): 7aba1cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -89
app.py CHANGED
@@ -1,102 +1,105 @@
1
- import cv2
2
  import gradio as gr
3
- import tensorflow as tf
4
  import numpy as np
 
5
  from PIL import Image
6
 
7
- title = "Welcome on your first sketch recognition app!"
8
-
9
- head = (
10
- "<center>"
11
- "The robot was trained to classify numbers (from 0 to 9). To test it, write your number in the space provided."
12
- "</center>"
13
- )
14
-
15
- # Model yükleniyor
16
- model = tf.keras.models.load_model("number_recognition_model_colab.keras")
17
 
 
 
18
 
19
- img_size = 28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- def predict(img):
24
- try:
25
- # Enhanced image validation and conversion
26
- if img is None:
27
- raise ValueError("No image provided")
28
-
29
- # Convert to numpy array if it's a PIL Image
30
- if isinstance(img, Image.Image):
31
- img = np.array(img)
32
-
33
- # Handle base64 image strings
34
- elif isinstance(img, str):
35
- # Check if it's a base64 data URL
36
- if img.startswith('data:image'):
37
- # Split and decode base64 part
38
- img = img.split(',')[1]
39
-
40
- # Decode base64 to image
41
- try:
42
- img = Image.open(io.BytesIO(base64.b64decode(img)))
43
- img = np.array(img)
44
- except Exception as e:
45
- print(f"Base64 decoding error: {e}")
46
- raise ValueError("Invalid base64 image")
47
-
48
- # Validate numpy array
49
- if not isinstance(img, np.ndarray):
50
- raise ValueError("Input could not be converted to a valid image")
51
-
52
- # Print initial image details for debugging
53
- print(f"Initial image type: {type(img)}, shape: {img.shape}")
54
-
55
- # Handle color channels
56
- if img.ndim == 3:
57
- if img.shape[-1] == 3: # Color image
58
- img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
59
- elif img.shape[-1] == 4: # RGBA image
60
- img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
61
-
62
- # Ensure single channel
63
- if img.ndim == 2:
64
- img = np.expand_dims(img, axis=-1)
65
-
66
- # Resize and normalize
67
- img = cv2.resize(img, (img_size, img_size))
68
- img = img.astype('float32') / 255.0
69
- img = img.reshape(1, img_size, img_size, 1)
70
-
71
- # Print processed image details
72
- print(f"Processed image shape: {img.shape}")
73
-
74
- # Get predictions from the model
75
- preds = model.predict(img)[0]
76
-
77
- # Print predictions for debugging
78
- print("Predictions:", preds)
79
-
80
- # Return predictions as a dictionary
81
- return {label: float(pred) for label, pred in zip(labels, preds)}
82
 
83
- except Exception as e:
84
- # Comprehensive error logging
85
- print(f"Full error during prediction: {e}")
86
- return {"Error": str(e)}
87
-
88
-
89
 
90
- # Set up the Gradio interface with the input as a sketchpad and output as labels
91
- label = gr.Label(num_top_classes=3)
 
 
 
 
 
 
92
 
93
- # Gradio arayüzü
94
- interface = gr.Interface(
95
- fn=predict,
96
- inputs=gr.Sketchpad(type="pil"),
97
- outputs=label,
98
- title="Sketch Recognition App",
99
- description="Draw a number (0-9) and see the model's top predictions."
 
 
 
 
 
 
 
 
 
100
  )
101
 
102
- interface.launch(debug=True, share=True)
 
 
 
 
1
  import gradio as gr
 
2
  import numpy as np
3
+ import tensorflow as tf
4
  from PIL import Image
5
 
6
+ # Load the pre-trained model
7
+ model = tf.keras.models.load_model('number_recognition_model_colab.h5')
 
 
 
 
 
 
 
 
8
 
9
+ # Class names for MNIST digits
10
+ classes_names = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
11
 
12
+ def preprocess_image(image):
13
+ """
14
+ Preprocess the input image to match the model's expected input
15
+ - Resize to 28x28
16
+ - Convert to grayscale
17
+ - Normalize pixel values
18
+ """
19
+ # Convert to grayscale if image is RGB
20
+ if len(image.shape) == 3:
21
+ image = np.mean(image, axis=2)
22
+
23
+ # Resize to 28x28
24
+ image = Image.fromarray(image.astype('uint8')).resize((28, 28), Image.LANCZOS)
25
+ image = np.array(image)
26
+
27
+ # Normalize
28
+ image = image.astype("float32") / 255.0
29
+
30
+ # Reshape to match model input shape
31
+ image = image.reshape(1, 28, 28, 1)
32
+
33
+ return image
34
 
35
+ def predict_digit(image):
36
+ """
37
+ Predict the digit in the input image
38
+ """
39
+ # Preprocess the image
40
+ processed_image = preprocess_image(image)
41
+
42
+ # Make prediction
43
+ predictions = model.predict(processed_image)
44
+
45
+ # Get the predicted class
46
+ predicted_class = np.argmax(predictions[0])
47
+ predicted_label = classes_names[predicted_class]
48
+ confidence = predictions[0][predicted_class] * 100
49
+
50
+ # Create a more detailed output
51
+ result_text = f"Predicted Digit: {predicted_label}\nConfidence: {confidence:.2f}%"
52
+
53
+ # Create bar chart of probabilities
54
+ probabilities = predictions[0] * 100
55
+
56
+ return result_text, probabilities
57
 
58
+ def create_probability_plot(probabilities):
59
+ """
60
+ Create a bar plot of digit probabilities
61
+ """
62
+ import matplotlib.pyplot as plt
63
+
64
+ plt.figure(figsize=(10, 5))
65
+ plt.bar(classes_names, probabilities)
66
+ plt.title('Digit Probability Distribution')
67
+ plt.xlabel('Digits')
68
+ plt.ylabel('Probability (%)')
69
+ plt.ylim(0, 100)
70
+
71
+ # Rotate x-axis labels
72
+ plt.xticks(rotation=45)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
+ return plt
 
 
 
 
 
75
 
76
+ # Create Gradio interface
77
+ def gradio_predict(image):
78
+ """
79
+ Wrapper function for Gradio interface
80
+ """
81
+ result_text, probabilities = predict_digit(image)
82
+ prob_plot = create_probability_plot(probabilities)
83
+ return result_text, prob_plot
84
 
85
+ # Set up the Gradio interface
86
+ iface = gr.Interface(
87
+ fn=gradio_predict,
88
+ inputs=gr.Image(type="numpy", image_mode="L"),
89
+ outputs=[
90
+ gr.Textbox(label="Prediction"),
91
+ gr.Plot(label="Probability Distribution")
92
+ ],
93
+ title="MNIST Digit Recognizer",
94
+ description="Draw a single-digit number (0-9) and the model will predict which digit it is!",
95
+ allow_flagging="never",
96
+ examples=[
97
+ ["example_zero.png"],
98
+ ["example_one.png"],
99
+ ["example_two.png"]
100
+ ]
101
  )
102
 
103
+ # Launch the interface
104
+ if __name__ == "__main__":
105
+ iface.launch()