cisemh commited on
Commit
67aae63
·
verified ·
1 Parent(s): 80396b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -36
app.py CHANGED
@@ -22,54 +22,69 @@ labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight"
22
 
23
  def predict(img):
24
  try:
25
- # Check if the image is a PIL object and convert it to NumPy array
 
 
 
 
26
  if isinstance(img, Image.Image):
27
  img = np.array(img)
28
-
29
- # If the image is base64, convert it to a PIL image and then to a NumPy array
30
- elif isinstance(img, str): # base64 image
31
- img = Image.open(io.BytesIO(base64.b64decode(img.split(',')[1])))
32
- img = np.array(img)
33
-
34
- # If the input is still not a NumPy array, raise an error
 
 
 
 
 
 
 
 
 
 
35
  if not isinstance(img, np.ndarray):
36
- raise ValueError("Input is not a valid image")
37
-
38
- # Print shape and type of the input image
39
  print(f"Initial image type: {type(img)}, shape: {img.shape}")
40
-
41
- # Ensure the image is in grayscale and has a single channel
42
- if img.ndim == 3 and img.shape[-1] == 3:
43
- img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
44
- elif img.ndim == 2:
 
 
 
 
 
45
  img = np.expand_dims(img, axis=-1)
46
-
47
- # Print the shape of the grayscale image
48
- print(f"Grayscale image shape: {img.shape}")
49
-
50
- # Resize the image
51
  img = cv2.resize(img, (img_size, img_size))
52
-
53
- # Normalize the image
54
  img = img.astype('float32') / 255.0
55
  img = img.reshape(1, img_size, img_size, 1)
56
-
57
- # Print the shape after resizing and normalizing
58
  print(f"Processed image shape: {img.shape}")
59
-
60
- # Get the predictions from the model
61
  preds = model.predict(img)[0]
62
-
63
- # Print the predictions
64
  print("Predictions:", preds)
65
-
66
- # Return the predictions for each label
67
  return {label: float(pred) for label, pred in zip(labels, preds)}
 
68
  except Exception as e:
69
- # Print the exception to the console
70
- print(f"Error during prediction: {e}")
71
  return {"Error": str(e)}
72
-
73
 
74
 
75
  # Set up the Gradio interface with the input as a sketchpad and output as labels
@@ -78,8 +93,8 @@ label = gr.Label(num_top_classes=3)
78
  # Gradio arayüzü
79
  interface = gr.Interface(
80
  fn=predict,
81
- inputs="sketchpad",
82
- outputs="textbox",
83
  title="Sketch Recognition App",
84
  description="Draw a number (0-9) and see the model's top predictions."
85
  )
 
22
 
23
  def predict(img):
24
  try:
25
+ # Enhanced image validation and conversion
26
+ if img is None:
27
+ raise ValueError("No image provided")
28
+
29
+ # Convert to numpy array if it's a PIL Image
30
  if isinstance(img, Image.Image):
31
  img = np.array(img)
32
+
33
+ # Handle base64 image strings
34
+ elif isinstance(img, str):
35
+ # Check if it's a base64 data URL
36
+ if img.startswith('data:image'):
37
+ # Split and decode base64 part
38
+ img = img.split(',')[1]
39
+
40
+ # Decode base64 to image
41
+ try:
42
+ img = Image.open(io.BytesIO(base64.b64decode(img)))
43
+ img = np.array(img)
44
+ except Exception as e:
45
+ print(f"Base64 decoding error: {e}")
46
+ raise ValueError("Invalid base64 image")
47
+
48
+ # Validate numpy array
49
  if not isinstance(img, np.ndarray):
50
+ raise ValueError("Input could not be converted to a valid image")
51
+
52
+ # Print initial image details for debugging
53
  print(f"Initial image type: {type(img)}, shape: {img.shape}")
54
+
55
+ # Handle color channels
56
+ if img.ndim == 3:
57
+ if img.shape[-1] == 3: # Color image
58
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
59
+ elif img.shape[-1] == 4: # RGBA image
60
+ img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
61
+
62
+ # Ensure single channel
63
+ if img.ndim == 2:
64
  img = np.expand_dims(img, axis=-1)
65
+
66
+ # Resize and normalize
 
 
 
67
  img = cv2.resize(img, (img_size, img_size))
 
 
68
  img = img.astype('float32') / 255.0
69
  img = img.reshape(1, img_size, img_size, 1)
70
+
71
+ # Print processed image details
72
  print(f"Processed image shape: {img.shape}")
73
+
74
+ # Get predictions from the model
75
  preds = model.predict(img)[0]
76
+
77
+ # Print predictions for debugging
78
  print("Predictions:", preds)
79
+
80
+ # Return predictions as a dictionary
81
  return {label: float(pred) for label, pred in zip(labels, preds)}
82
+
83
  except Exception as e:
84
+ # Comprehensive error logging
85
+ print(f"Full error during prediction: {e}")
86
  return {"Error": str(e)}
87
+
88
 
89
 
90
  # Set up the Gradio interface with the input as a sketchpad and output as labels
 
93
  # Gradio arayüzü
94
  interface = gr.Interface(
95
  fn=predict,
96
+ inputs=gr.Sketchpad(type="pil"),
97
+ outputs=gr.Label(num_top_classes=3), ,
98
  title="Sketch Recognition App",
99
  description="Draw a number (0-9) and see the model's top predictions."
100
  )