cisemh committed on
Commit
d1c466d
·
verified ·
1 Parent(s): 80ee6b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -91
app.py CHANGED
@@ -1,102 +1,83 @@
1
- import cv2
2
- import gradio as gr
3
- import tensorflow as tf
4
  import numpy as np
5
- from PIL import Image
 
 
6
 
7
- title = "Welcome on your first sketch recognition app!"
 
 
 
 
 
 
 
 
8
 
9
- head = (
10
- "<center>"
11
- "The robot was trained to classify numbers (from 0 to 9). To test it, write your number in the space provided."
12
- "</center>"
13
- )
 
 
 
 
 
14
 
15
- # Model yükleniyor
16
- model = tf.keras.models.load_model("./number_recognition_model_colab.keras")
 
 
 
 
 
17
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- img_size = 28
 
 
 
 
20
 
21
- labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
 
22
 
23
- def predict(img):
24
- try:
25
- # Enhanced image validation and conversion
26
- if img is None:
27
- raise ValueError("No image provided")
28
-
29
- # Convert to numpy array if it's a PIL Image
30
- if isinstance(img, Image.Image):
31
- img = np.array(img)
32
-
33
- # Handle base64 image strings
34
- elif isinstance(img, str):
35
- # Check if it's a base64 data URL
36
- if img.startswith('data:image'):
37
- # Split and decode base64 part
38
- img = img.split(',')[1]
39
-
40
- # Decode base64 to image
41
- try:
42
- img = Image.open(io.BytesIO(base64.b64decode(img)))
43
- img = np.array(img)
44
- except Exception as e:
45
- print(f"Base64 decoding error: {e}")
46
- raise ValueError("Invalid base64 image")
47
-
48
- # Validate numpy array
49
- if not isinstance(img, np.ndarray):
50
- raise ValueError("Input could not be converted to a valid image")
51
-
52
- # Print initial image details for debugging
53
- print(f"Initial image type: {type(img)}, shape: {img.shape}")
54
-
55
- # Handle color channels
56
- if img.ndim == 3:
57
- if img.shape[-1] == 3: # Color image
58
- img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
59
- elif img.shape[-1] == 4: # RGBA image
60
- img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
61
-
62
- # Ensure single channel
63
- if img.ndim == 2:
64
- img = np.expand_dims(img, axis=-1)
65
-
66
- # Resize and normalize
67
- img = cv2.resize(img, (img_size, img_size))
68
- img = img.astype('float32') / 255.0
69
- img = img.reshape(1, img_size, img_size, 1)
70
-
71
- # Print processed image details
72
- print(f"Processed image shape: {img.shape}")
73
-
74
- # Get predictions from the model
75
- preds = model.predict(img)[0]
76
-
77
- # Print predictions for debugging
78
- print("Predictions:", preds)
79
-
80
- # Return predictions as a dictionary
81
- return {label: float(pred) for label, pred in zip(labels, preds)}
82
-
83
- except Exception as e:
84
- # Comprehensive error logging
85
- print(f"Full error during prediction: {e}")
86
- return {"Error": str(e)}
87
-
88
-
89
 
90
- # Set up the Gradio interface with the input as a sketchpad and output as labels
91
- label = gr.Label(num_top_classes=3)
 
 
 
 
92
 
93
- # Gradio arayüzü
94
- interface = gr.Interface(
95
- fn=predict,
96
- inputs=gr.Sketchpad(type="pil"),
97
- outputs=label,
98
- title="Sketch Recognition App",
99
- description="Draw a number (0-9) and see the model's top predictions."
100
- )
 
 
 
 
101
 
102
- interface.launch(debug=True, share=True)
 
 
 
 
 
 
1
  import numpy as np
2
+ import tensorflow as tf
3
+ from tensorflow import keras
4
+ import gradio as gr
5
 
6
+ # Load and preprocess the MNIST dataset
7
def load_data():
    """Fetch MNIST and return it in the layout the CNN expects.

    Image arrays are cast to float32, divided by 255 and reshaped to
    ``(-1, 28, 28, 1)``. Label arrays are passed through unchanged.

    Returns:
        ``((X_train, y_train), (X_test, y_test))``.
    """

    def _prepare(images):
        # Scale first, then append the trailing single-channel axis.
        scaled = images.astype("float32") / 255
        return scaled.reshape(-1, 28, 28, 1)

    train_split, test_split = tf.keras.datasets.mnist.load_data()
    train_images, train_labels = train_split
    test_images, test_labels = test_split
    return (_prepare(train_images), train_labels), (_prepare(test_images), test_labels)
15
 
16
+ # Build the CNN model
17
def build_model(input_shape, num_classes):
    """Assemble the (uncompiled) digit-classification CNN.

    Layer stack: Conv2D(28 filters, 3x3, ReLU) -> MaxPooling2D(2x2) ->
    Flatten -> Dense(128, ReLU) -> Dense(num_classes, softmax).

    Args:
        input_shape: Shape of a single sample, handed to ``keras.layers.Input``.
        num_classes: Number of units in the final softmax layer.

    Returns:
        A ``keras.models.Model`` mapping the input tensor to the softmax output.
    """
    model_input = keras.layers.Input(input_shape)
    stack = (
        keras.layers.Conv2D(28, kernel_size=(3, 3), activation='relu'),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(num_classes, activation='softmax'),
    )

    # Thread the tensor through each layer in declaration order.
    tensor = model_input
    for layer in stack:
        tensor = layer(tensor)

    return keras.models.Model(inputs=model_input, outputs=tensor)
26
 
27
+ # Preprocess input for prediction
28
def preprocess_image(image):
    """Turn an input image into a normalized ``(1, 28, 28, 1)`` float32 batch.

    Accepts either a PIL-style image (any object exposing ``convert``) or an
    array-like of pixels. Array input may be 2-D grayscale, or 3-D with 1, 3
    or 4 channels in the last axis (alpha is ignored). Images that are not
    already 28x28 are resized with nearest-neighbour sampling so the final
    reshape cannot fail on size.

    Args:
        image: PIL-style image or array-like of pixel values in 0..255.

    Returns:
        ``numpy.ndarray`` of shape ``(1, 28, 28, 1)``, dtype float32, with the
        pixel values divided by 255.

    Raises:
        ValueError: If ``image`` is ``None``, empty, or has an unsupported
            shape.
    """
    if image is None:
        raise ValueError("No image provided")

    if hasattr(image, "convert"):
        # PIL-style image: let the image object do the grayscale conversion.
        pixels = np.array(image.convert('L'))
    else:
        pixels = np.asarray(image)
        if pixels.ndim == 3:
            channels = pixels.shape[-1]
            if channels == 1:
                pixels = pixels[..., 0]
            elif channels in (3, 4):
                # ITU-R 601 luma weights; any alpha channel is dropped.
                luma = np.array([0.299, 0.587, 0.114], dtype="float32")
                pixels = pixels[..., :3].astype("float32") @ luma
            else:
                raise ValueError(
                    f"Unsupported channel count in image of shape {pixels.shape}"
                )

    if pixels.ndim != 2 or 0 in pixels.shape:
        raise ValueError(
            f"Expected a non-empty 2-D grayscale image, got shape {pixels.shape}"
        )

    height, width = pixels.shape
    if (height, width) != (28, 28):
        # Nearest-neighbour resize: pick one source row/column per target cell.
        rows = np.arange(28) * height // 28
        cols = np.arange(28) * width // 28
        pixels = pixels[np.ix_(rows, cols)]

    pixels = pixels.astype("float32") / 255  # Normalize to 0..1
    return pixels.reshape(1, 28, 28, 1)  # Add batch and channel axes
34
 
35
+ # Predict digit
36
def predict_digit(image):
    """Classify the digit in *image* and report the score of every class.

    Uses the module-level ``model`` and ``classes_names`` that the
    ``__main__`` section of this script defines before launching the UI.

    Args:
        image: Input image, forwarded unchanged to ``preprocess_image``.

    Returns:
        ``(label, results)`` where ``label`` is the name of the top-scoring
        class and ``results`` maps each class name to its score as a float.
    """
    batch = preprocess_image(image)
    # A single image goes in, so only the first output row is relevant.
    scores = model.predict(batch)[0]
    label = classes_names[int(np.argmax(scores))]
    results = {name: float(score) for name, score in zip(classes_names, scores)}
    return label, results
45
 
46
if __name__ == "__main__":
    import os  # Local import: only this script section needs it.

    # Parameters
    classes_names = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
    input_shape = (28, 28, 1)
    num_classes = len(classes_names)

    # Load data
    (X_train, y_train), (X_test, y_test) = load_data()

    # Build and train model
    model = build_model(input_shape, num_classes)
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    print("Training model...")
    model.fit(X_train, y_train, epochs=3, batch_size=64)  # Quick training for demonstration

    # Gradio Interface
    title = "Welcome to Your First Sketch Recognition App!"
    description = (
        "The robot was trained to classify numbers (from 0 to 9). To test it, draw your number in the space provided."
    )
    # Only advertise example images that actually exist next to the script,
    # so a missing file cannot break interface construction.
    example_files = ["example_image.png"]  # You can add example images here.
    examples = [[path] for path in example_files if os.path.isfile(path)] or None

    # NOTE(review): gr.inputs / gr.outputs are the legacy Gradio namespaces;
    # newer releases expose gr.Image / gr.Textbox / gr.Label instead and may
    # not accept shape/invert_colors. Confirm against the installed version.
    interface = gr.Interface(
        fn=predict_digit,
        # type="pil" because preprocess_image calls image.convert('L'), which
        # needs a PIL image rather than the component's default output type.
        inputs=gr.inputs.Image(
            shape=(28, 28),
            image_mode="L",
            invert_colors=True,
            type="pil",
            label="Draw a Digit",
        ),
        outputs=[
            gr.outputs.Textbox(label="Predicted Digit"),
            gr.outputs.Label(num_top_classes=10, label="Prediction Confidence"),
        ],
        title=title,
        description=description,
        examples=examples,
        live=True,
    )

    # Launch Gradio interface
    print("Launching Gradio interface...")
    interface.launch()