Eric P. Nusbaum commited on
Commit
029bb24
·
1 Parent(s): b7bd655

Updated app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -76
app.py CHANGED
@@ -1,5 +1,3 @@
1
- # app.py
2
-
3
  import gradio as gr
4
  import tensorflow as tf
5
  import numpy as np
@@ -58,12 +56,12 @@ def preprocess_image(image):
58
  - Normalize pixel values
59
  - Convert RGB to BGR if required by the model
60
  """
61
- image = image.convert('RGB') # Ensure image is in RGB
62
  image = image.resize((TARGET_WIDTH, TARGET_HEIGHT))
63
  image_np = np.array(image).astype(np.float32)
64
  image_np = image_np / 255.0 # Normalize to [0,1]
65
- if image_np.shape[2] == 3:
66
- image_np = image_np[:, :, (2, 1, 0)] # Convert RGB to BGR if required
 
67
  image_np = np.expand_dims(image_np, axis=0) # Add batch dimension
68
  return image_np
69
 
@@ -84,13 +82,17 @@ def draw_boxes(image, boxes, classes, scores, threshold=0.5):
84
  font = ImageFont.truetype("arial.ttf", 15)
85
  except IOError:
86
  font = ImageFont.load_default()
87
-
88
- detections = False # Flag to check if any detections are above threshold
89
 
 
 
 
 
 
 
90
  for box, cls, score in zip(boxes[0], classes[0], scores[0]):
91
  if score < threshold:
92
  continue
93
- detections = True
94
  # Convert box coordinates from normalized to absolute
95
  ymin, xmin, ymax, xmax = box
96
  left = xmin * image.width
@@ -102,42 +104,24 @@ def draw_boxes(image, boxes, classes, scores, threshold=0.5):
102
  draw.rectangle([(left, top), (right, bottom)], outline="red", width=2)
103
 
104
  # Prepare label
105
- label = f"{labels[int(cls) - 1]}: {score:.2f}"
 
 
 
 
106
 
107
  # Calculate text size using textbbox
108
- text_bbox = draw.textbbox((0, 0), label, font=font)
109
  text_width = text_bbox[2] - text_bbox[0]
110
  text_height = text_bbox[3] - text_bbox[1]
111
 
112
  # Draw label background
113
- draw.rectangle([(left, top - text_height - 4), (left + text_width + 4, top)], fill="red")
 
114
 
115
  # Draw text
116
- draw.text((left + 2, top - text_height - 2), label, fill="white", font=font)
117
-
118
- if not detections:
119
- # Optionally, you can add text indicating no detections were found
120
- try:
121
- font_large = ImageFont.truetype("arial.ttf", 20)
122
- except IOError:
123
- font_large = ImageFont.load_default()
124
- no_detect_text = "No detections found."
125
- text_bbox = draw.textbbox((0, 0), no_detect_text, font=font_large)
126
- text_width = text_bbox[2] - text_bbox[0]
127
- text_height = text_bbox[3] - text_bbox[1]
128
- draw.rectangle(
129
- [
130
- ((image.width - text_width) / 2 - 10, (image.height - text_height) / 2 - 10),
131
- ((image.width + text_width) / 2 + 10, (image.height + text_height) / 2 + 10)
132
- ],
133
- fill="black"
134
- )
135
- draw.text(
136
- ((image.width - text_width) / 2, (image.height - text_height) / 2),
137
- no_detect_text,
138
- fill="white",
139
- font=font_large
140
- )
141
 
142
  return image
143
 
@@ -150,26 +134,30 @@ def predict(image):
150
  PIL.Image: Annotated image with bounding boxes and labels.
151
  """
152
  try:
153
- print("Starting prediction...")
154
  # Preprocess the image
155
  input_array = preprocess_image(image)
156
- print(f"Preprocessed image shape: {input_array.shape}")
157
-
 
 
158
  # Run inference
159
  boxes, classes, scores = sess.run(
160
  [detected_boxes, detected_classes, detected_scores],
161
  feed_dict={input_tensor: input_array}
162
  )
163
- print(f"Inference completed. Number of detections: {len(boxes[0])}")
164
-
165
- # Check if detections are present
166
- if boxes.shape[1] == 0:
167
- print("No detections returned by the model.")
168
-
 
 
 
 
169
  # Annotate the image with bounding boxes and labels
170
  annotated_image = draw_boxes(image.copy(), boxes, classes, scores, threshold=0.5)
171
- print("Annotated image created successfully.")
172
-
173
  return annotated_image
174
 
175
  except Exception as e:
@@ -177,34 +165,57 @@ def predict(image):
177
  print(f"Exception during prediction: {e}")
178
 
179
  # Return an error image with the error message
180
- error_image = Image.new('RGB', (500, 500), color=(255, 0, 0))
181
- draw = ImageDraw.Draw(error_image)
182
- try:
183
- font = ImageFont.truetype("arial.ttf", 20)
184
- except IOError:
185
- font = ImageFont.load_default()
186
- error_text = "Error during prediction."
187
-
188
- # Calculate text size using textbbox
189
- text_bbox = draw.textbbox((0, 0), error_text, font=font)
190
- text_width = text_bbox[2] - text_bbox[0]
191
- text_height = text_bbox[3] - text_bbox[1]
192
-
193
- # Center the text
194
- draw.rectangle(
195
- [
196
- ((500 - text_width) / 2 - 10, (500 - text_height) / 2 - 10),
197
- ((500 + text_width) / 2 + 10, (500 + text_height) / 2 + 10)
198
- ],
199
- fill="black"
200
- )
201
- draw.text(
202
- ((500 - text_width) / 2, (500 - text_height) / 2),
203
- error_text,
204
- fill="white",
205
- font=font
206
- )
207
- return error_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
  # Define Gradio interface using the new API
210
  title = "JunkWaxHero 🦸‍♂️ - Baseball Card Set Identifier"
@@ -223,7 +234,7 @@ iface = gr.Interface(
223
  title=title,
224
  description=description,
225
  examples=valid_examples,
226
- flagging_mode="never" # Updated parameter
227
  )
228
 
229
  if __name__ == "__main__":
 
 
 
1
  import gradio as gr
2
  import tensorflow as tf
3
  import numpy as np
 
56
  - Normalize pixel values
57
  - Convert RGB to BGR if required by the model
58
  """
 
59
  image = image.resize((TARGET_WIDTH, TARGET_HEIGHT))
60
  image_np = np.array(image).astype(np.float32)
61
  image_np = image_np / 255.0 # Normalize to [0,1]
62
+ if image_np.shape[-1] == 3:
63
+ # Convert RGB to BGR if required by your model
64
+ image_np = image_np[..., (2, 1, 0)]
65
  image_np = np.expand_dims(image_np, axis=0) # Add batch dimension
66
  return image_np
67
 
 
82
  font = ImageFont.truetype("arial.ttf", 15)
83
  except IOError:
84
  font = ImageFont.load_default()
 
 
85
 
86
+ # If there are no detections at all
87
+ if boxes.shape[0] == 0 or boxes.shape[1] == 0:
88
+ # Return the original image without annotation
89
+ return image
90
+
91
+ # Otherwise, proceed to draw bounding boxes
92
  for box, cls, score in zip(boxes[0], classes[0], scores[0]):
93
  if score < threshold:
94
  continue
95
+
96
  # Convert box coordinates from normalized to absolute
97
  ymin, xmin, ymax, xmax = box
98
  left = xmin * image.width
 
104
  draw.rectangle([(left, top), (right, bottom)], outline="red", width=2)
105
 
106
  # Prepare label
107
+ cls_index = int(cls) - 1 # If your classes are 1-indexed
108
+ if cls_index < 0 or cls_index >= len(labels):
109
+ label_str = f"cls_{int(cls)}: {score:.2f}"
110
+ else:
111
+ label_str = f"{labels[cls_index]}: {score:.2f}"
112
 
113
  # Calculate text size using textbbox
114
+ text_bbox = draw.textbbox((0, 0), label_str, font=font)
115
  text_width = text_bbox[2] - text_bbox[0]
116
  text_height = text_bbox[3] - text_bbox[1]
117
 
118
  # Draw label background
119
+ draw.rectangle([(left, top - text_height - 4),
120
+ (left + text_width + 4, top)], fill="red")
121
 
122
  # Draw text
123
+ draw.text((left + 2, top - text_height - 2),
124
+ label_str, fill="white", font=font)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  return image
127
 
 
134
  PIL.Image: Annotated image with bounding boxes and labels.
135
  """
136
  try:
 
137
  # Preprocess the image
138
  input_array = preprocess_image(image)
139
+
140
+ # Debug prints
141
+ print(f"[DEBUG] Input shape to model: {input_array.shape}")
142
+
143
  # Run inference
144
  boxes, classes, scores = sess.run(
145
  [detected_boxes, detected_classes, detected_scores],
146
  feed_dict={input_tensor: input_array}
147
  )
148
+
149
+ # Debug prints
150
+ print(f"[DEBUG] boxes shape: {boxes.shape}, classes shape: {classes.shape}, scores shape: {scores.shape}")
151
+
152
+ # Check if the output arrays have zero detections
153
+ # e.g. boxes could have shape (1, 0, 4) if no detections are found
154
+ if boxes.size == 0 or classes.size == 0 or scores.size == 0:
155
+ print("[DEBUG] No detections returned by the model.")
156
+ return _draw_no_detection_message(image)
157
+
158
  # Annotate the image with bounding boxes and labels
159
  annotated_image = draw_boxes(image.copy(), boxes, classes, scores, threshold=0.5)
160
+ print("[DEBUG] Annotation completed.")
 
161
  return annotated_image
162
 
163
  except Exception as e:
 
165
  print(f"Exception during prediction: {e}")
166
 
167
  # Return an error image with the error message
168
+ return _draw_error_message()
169
+
170
+ def _draw_no_detection_message(image):
171
+ """Draws a simple 'No detections found' message on the image."""
172
+ draw = ImageDraw.Draw(image)
173
+ try:
174
+ font = ImageFont.truetype("arial.ttf", 20)
175
+ except IOError:
176
+ font = ImageFont.load_default()
177
+ message = "No detections found."
178
+ text_bbox = draw.textbbox((0, 0), message, font=font)
179
+ text_width = text_bbox[2] - text_bbox[0]
180
+ text_height = text_bbox[3] - text_bbox[1]
181
+
182
+ # Center the message
183
+ x = (image.width - text_width) / 2
184
+ y = (image.height - text_height) / 2
185
+ draw.rectangle(
186
+ [(x - 10, y - 10), (x + text_width + 10, y + text_height + 10)],
187
+ fill="black"
188
+ )
189
+ draw.text((x, y), message, fill="white", font=font)
190
+ return image
191
+
192
+ def _draw_error_message():
193
+ """Creates a red image with a centered error message."""
194
+ error_image = Image.new('RGB', (500, 500), color=(255, 0, 0))
195
+ draw = ImageDraw.Draw(error_image)
196
+ try:
197
+ font = ImageFont.truetype("arial.ttf", 20)
198
+ except IOError:
199
+ font = ImageFont.load_default()
200
+ error_text = "Error during prediction."
201
+ text_bbox = draw.textbbox((0, 0), error_text, font=font)
202
+ text_width = text_bbox[2] - text_bbox[0]
203
+ text_height = text_bbox[3] - text_bbox[1]
204
+
205
+ draw.rectangle(
206
+ [
207
+ ((500 - text_width) / 2 - 10, (500 - text_height) / 2 - 10),
208
+ ((500 + text_width) / 2 + 10, (500 + text_height) / 2 + 10)
209
+ ],
210
+ fill="black"
211
+ )
212
+ draw.text(
213
+ ((500 - text_width) / 2, (500 - text_height) / 2),
214
+ error_text,
215
+ fill="white",
216
+ font=font
217
+ )
218
+ return error_image
219
 
220
  # Define Gradio interface using the new API
221
  title = "JunkWaxHero 🦸‍♂️ - Baseball Card Set Identifier"
 
234
  title=title,
235
  description=description,
236
  examples=valid_examples,
237
+ flagging_mode="never" # Use new Gradio parameter
238
  )
239
 
240
  if __name__ == "__main__":