cisemh committed on
Commit
fae307d
·
verified ·
1 Parent(s): babe4e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -101
app.py CHANGED
@@ -1,105 +1,74 @@
1
- import gradio as gr
2
  import numpy as np
 
3
  import tensorflow as tf
4
- from PIL import Image
5
-
6
- # Load the pre-trained model
7
- model = tf.keras.models.load_model('number_recognition_model_colab.h5')
8
-
9
- # Class names for MNIST digits
10
- classes_names = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
11
-
12
- def preprocess_image(image):
13
- """
14
- Preprocess the input image to match the model's expected input
15
- - Resize to 28x28
16
- - Convert to grayscale
17
- - Normalize pixel values
18
- """
19
- # Convert to grayscale if image is RGB
20
- if len(image.shape) == 3:
21
- image = np.mean(image, axis=2)
22
-
23
- # Resize to 28x28
24
- image = Image.fromarray(image.astype('uint8')).resize((28, 28), Image.LANCZOS)
25
- image = np.array(image)
26
-
27
- # Normalize
28
- image = image.astype("float32") / 255.0
29
-
30
- # Reshape to match model input shape
31
- image = image.reshape(1, 28, 28, 1)
32
-
33
- return image
34
-
35
- def predict_digit(image):
36
- """
37
- Predict the digit in the input image
38
- """
39
- # Preprocess the image
40
- processed_image = preprocess_image(image)
41
-
42
- # Make prediction
43
- predictions = model.predict(processed_image)
44
-
45
- # Get the predicted class
46
- predicted_class = np.argmax(predictions[0])
47
- predicted_label = classes_names[predicted_class]
48
- confidence = predictions[0][predicted_class] * 100
49
-
50
- # Create a more detailed output
51
- result_text = f"Predicted Digit: {predicted_label}\nConfidence: {confidence:.2f}%"
52
-
53
- # Create bar chart of probabilities
54
- probabilities = predictions[0] * 100
55
-
56
- return result_text, probabilities
57
-
58
- def create_probability_plot(probabilities):
59
- """
60
- Create a bar plot of digit probabilities
61
- """
62
- import matplotlib.pyplot as plt
63
-
64
- plt.figure(figsize=(10, 5))
65
- plt.bar(classes_names, probabilities)
66
- plt.title('Digit Probability Distribution')
67
- plt.xlabel('Digits')
68
- plt.ylabel('Probability (%)')
69
- plt.ylim(0, 100)
70
-
71
- # Rotate x-axis labels
72
- plt.xticks(rotation=45)
73
-
74
- return plt
75
-
76
- # Create Gradio interface
77
- def gradio_predict(image):
78
- """
79
- Wrapper function for Gradio interface
80
- """
81
- result_text, probabilities = predict_digit(image)
82
- prob_plot = create_probability_plot(probabilities)
83
- return result_text, prob_plot
84
-
85
- # Set up the Gradio interface
86
- iface = gr.Interface(
87
- fn=gradio_predict,
88
- inputs=gr.Image(type="numpy", image_mode="L"),
89
- outputs=[
90
- gr.Textbox(label="Prediction"),
91
- gr.Plot(label="Probability Distribution")
92
- ],
93
- title="MNIST Digit Recognizer",
94
- description="Draw a single-digit number (0-9) and the model will predict which digit it is!",
95
- allow_flagging="never",
96
- examples=[
97
- ["example_zero.png"],
98
- ["example_one.png"],
99
- ["example_two.png"]
100
- ]
101
  )
102
 
103
- # Launch the interface
104
- if __name__ == "__main__":
105
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
+ import gradio as gr
3
  import tensorflow as tf
4
+ import cv2
5
+
6
+ # Load the trained MNIST model
7
+ model = tf.keras.models.load_model("./number_recognition_model_colab.keras")
8
+
9
+ # Class names (0 to 9)
10
+ labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
11
+
12
+ def predict(data):
13
+ # Extract the 'composite' key from the input dictionary
14
+ img = data["composite"]
15
+ img = np.array(img)
16
+
17
+ # Convert RGBA to RGB if needed
18
+ if img.shape[-1] == 4: # RGBA
19
+ img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
20
+
21
+ # Convert RGB to Grayscale
22
+ if img.shape[-1] == 3: # RGB
23
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
24
+
25
+ # Resize image to 28x28
26
+ img = cv2.resize(img, (28, 28))
27
+
28
+ # Normalize pixel values to [0, 1]
29
+ img = img / 255.0
30
+
31
+ # Reshape to match model input (1, 28, 28, 1)
32
+ img = img.reshape(1, 28, 28, 1)
33
+
34
+ # Model predictions
35
+ preds = model.predict(img)[0]
36
+
37
+ print(preds)
38
+
39
+ # Get top 3 classes
40
+ top_3_classes = np.argsort(preds)[-3:][::-1]
41
+ top_3_probs = preds[top_3_classes]
42
+ class_names = [labels[i] for i in top_3_classes]
43
+ print(class_names, top_3_probs, top_3_classes)
44
+
45
+ # Return top 3 predictions as a dictionary
46
+ return {class_names[i]: float(top_3_probs[i]) for i in range(3)}
47
+
48
+ # Title and description
49
+ title = "Welcome to your first sketch recognition app!"
50
+ head = (
51
+ "<center>"
52
+ "<img src='./mnist-classes.png' width=400>"
53
+ "<p>The model is trained to classify numbers (from 0 to 9). "
54
+ "To test it, draw your number in the space provided (use the editing tools in the image editor).</p>"
55
+ "</center>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  )
57
 
58
+
59
+ with gr.Blocks(title=title) as demo:
60
+ # Display title and description
61
+ gr.Markdown(head)
62
+ gr.Markdown(ref)
63
+
64
+ with gr.Row():
65
+ # Using Sketchpad (a preconfigured ImageEditor) with type='numpy'
66
+ im = gr.Sketchpad(type="numpy", label="Draw your digit here (use brush and eraser)")
67
+
68
+ # Output label (top 3 predictions)
69
+ label = gr.Label(num_top_classes=3, label="Predictions")
70
+
71
+ # Trigger prediction whenever the image changes
72
+ im.change(predict, inputs=im, outputs=label)
73
+
74
+ demo.launch(share=True)