cisemh committed on
Commit
1dc2a74
·
verified ·
1 Parent(s): d0112b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -9
app.py CHANGED
@@ -1,6 +1,8 @@
 
1
  import gradio as gr
2
  import tensorflow as tf
3
  import numpy as np
 
4
 
5
  title = "Welcome on your first sketch recognition app!"
6
 
@@ -13,14 +15,88 @@ head = (
13
  # Load the model
14
  model = tf.keras.models.load_model("number_recognition_model_colab.keras")
15
 
16
- def recognize_digit(image):
17
- prediction = model.predict(np.reshape(image, (1, 28, 28))).tolist()[0]
18
- return {str(i): prediction[i] for i in range(10)}
19
 
20
- sketchpad = gr.Sketchpad(shape=(28, 28))
21
- gr.Interface(fn=recognize_digit,
22
- inputs=sketchpad,
23
- outputs="label",
24
- title="Handwritten Digits Classifier",
25
- description="This app uses lenet5 for handwritten digits classification").launch()
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
  import gradio as gr
3
  import tensorflow as tf
4
  import numpy as np
5
+ from PIL import Image
6
 
7
  title = "Welcome on your first sketch recognition app!"
8
 
 
15
  # Load the model
16
  model = tf.keras.models.load_model("number_recognition_model_colab.keras")
17
 
 
 
 
18
 
19
# Side length, in pixels, of the square image that `predict` resizes its
# input to before handing it to the model.
img_size = 28

# Class names, paired positionally with the model's output scores.
labels = [
    "zero",
    "one",
    "two",
    "three",
    "four",
    "five",
    "six",
    "seven",
    "eight",
    "nine",
]
22
+
23
def _to_array(img):
    """Normalise any supported input payload to a NumPy pixel array.

    Accepted inputs:
      * a NumPy array (returned unchanged),
      * a PIL image,
      * a base64 string, optionally prefixed with a ``data:image...,`` header,
      * a dict carrying the drawing under the ``"composite"`` key (the payload
        shape documented for Gradio's ImageEditor-based sketchpad).

    Raises:
        ValueError: if no image was supplied, the base64 payload cannot be
            decoded into an image, or the input type is not supported.
    """
    # Unwrap the editor payload first so an empty canvas (composite is None)
    # is reported as "No image provided" like a plain None input.
    if isinstance(img, dict):
        img = img.get("composite")

    if img is None:
        raise ValueError("No image provided")

    if isinstance(img, np.ndarray):
        return img

    if isinstance(img, str):
        # Imported locally: the module's import block does not provide them,
        # which previously made every base64 input fail with a NameError.
        import base64
        import io

        if img.startswith('data:image'):
            # Keep only the payload that follows the "data:image/...;base64," header.
            img = img.partition(',')[2]

        try:
            img = Image.open(io.BytesIO(base64.b64decode(img)))
        except Exception as e:
            print(f"Base64 decoding error: {e}")
            raise ValueError("Invalid base64 image") from e

    if isinstance(img, Image.Image):
        return np.array(img)

    raise ValueError("Input could not be converted to a valid image")


def predict(img):
    """Classify a hand-drawn digit and return a score per class name.

    The input is converted to a grayscale ``img_size`` x ``img_size`` float
    image scaled to [0, 1], reshaped to ``(1, img_size, img_size, 1)`` and
    passed to the module-level ``model``.

    Args:
        img: a PIL image, NumPy array, base64 string (raw or data URL), or a
            Gradio editor dict with a ``"composite"`` entry.

    Returns:
        dict mapping each entry of ``labels`` to its predicted score as a
        float. On any failure the error is printed and ``{"Error": <message>}``
        is returned instead of raising.
    """
    try:
        img = _to_array(img)

        # Debug output: what we received after conversion.
        print(f"Initial image type: {type(img)}, shape: {img.shape}")

        # Collapse colour channels to a single grayscale plane. Arrays built
        # from PIL images are in RGB(A) order, so the RGB conversion codes
        # are the right ones (BGR would swap the red/blue weights).
        if img.ndim == 3:
            channels = img.shape[-1]
            if channels == 3:
                img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            elif channels == 4:
                img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)

        # Resize, scale to [0, 1] and add the batch and channel axes.
        img = cv2.resize(img, (img_size, img_size))
        img = img.astype('float32') / 255.0
        img = img.reshape(1, img_size, img_size, 1)

        print(f"Processed image shape: {img.shape}")

        # The batch holds a single image, so take the first row of scores.
        preds = model.predict(img)[0]
        print("Predictions:", preds)

        return {label: float(pred) for label, pred in zip(labels, preds)}

    except Exception as e:
        # Deliberate best-effort boundary: report the failure to the UI
        # as a result dict rather than crashing the request.
        print(f"Full error during prediction: {e}")
        return {"Error": str(e)}
87
+
88
+
89
+
90
# Output component: a label widget limited to the three top classes.
label = gr.Label(num_top_classes=3)

# Gradio interface: a sketchpad configured with type="pil" is the input to
# `predict`, and its result is rendered by the label widget above.
interface = gr.Interface(
    predict,
    gr.Sketchpad(type="pil"),
    label,
    description="Draw a number (0-9) and see the model's top predictions.",
    title="Sketch Recognition App",
)

interface.launch(share=True, debug=True)