Spaces:
Sleeping
Sleeping
Eric P. Nusbaum
commited on
Commit
·
029bb24
1
Parent(s):
b7bd655
Updated app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1 |
-
# app.py
|
2 |
-
|
3 |
import gradio as gr
|
4 |
import tensorflow as tf
|
5 |
import numpy as np
|
@@ -58,12 +56,12 @@ def preprocess_image(image):
|
|
58 |
- Normalize pixel values
|
59 |
- Convert RGB to BGR if required by the model
|
60 |
"""
|
61 |
-
image = image.convert('RGB') # Ensure image is in RGB
|
62 |
image = image.resize((TARGET_WIDTH, TARGET_HEIGHT))
|
63 |
image_np = np.array(image).astype(np.float32)
|
64 |
image_np = image_np / 255.0 # Normalize to [0,1]
|
65 |
-
if image_np.shape[
|
66 |
-
|
|
|
67 |
image_np = np.expand_dims(image_np, axis=0) # Add batch dimension
|
68 |
return image_np
|
69 |
|
@@ -84,13 +82,17 @@ def draw_boxes(image, boxes, classes, scores, threshold=0.5):
|
|
84 |
font = ImageFont.truetype("arial.ttf", 15)
|
85 |
except IOError:
|
86 |
font = ImageFont.load_default()
|
87 |
-
|
88 |
-
detections = False # Flag to check if any detections are above threshold
|
89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
for box, cls, score in zip(boxes[0], classes[0], scores[0]):
|
91 |
if score < threshold:
|
92 |
continue
|
93 |
-
|
94 |
# Convert box coordinates from normalized to absolute
|
95 |
ymin, xmin, ymax, xmax = box
|
96 |
left = xmin * image.width
|
@@ -102,42 +104,24 @@ def draw_boxes(image, boxes, classes, scores, threshold=0.5):
|
|
102 |
draw.rectangle([(left, top), (right, bottom)], outline="red", width=2)
|
103 |
|
104 |
# Prepare label
|
105 |
-
|
|
|
|
|
|
|
|
|
106 |
|
107 |
# Calculate text size using textbbox
|
108 |
-
text_bbox = draw.textbbox((0, 0),
|
109 |
text_width = text_bbox[2] - text_bbox[0]
|
110 |
text_height = text_bbox[3] - text_bbox[1]
|
111 |
|
112 |
# Draw label background
|
113 |
-
draw.rectangle([(left, top - text_height - 4),
|
|
|
114 |
|
115 |
# Draw text
|
116 |
-
draw.text((left + 2, top - text_height - 2),
|
117 |
-
|
118 |
-
if not detections:
|
119 |
-
# Optionally, you can add text indicating no detections were found
|
120 |
-
try:
|
121 |
-
font_large = ImageFont.truetype("arial.ttf", 20)
|
122 |
-
except IOError:
|
123 |
-
font_large = ImageFont.load_default()
|
124 |
-
no_detect_text = "No detections found."
|
125 |
-
text_bbox = draw.textbbox((0, 0), no_detect_text, font=font_large)
|
126 |
-
text_width = text_bbox[2] - text_bbox[0]
|
127 |
-
text_height = text_bbox[3] - text_bbox[1]
|
128 |
-
draw.rectangle(
|
129 |
-
[
|
130 |
-
((image.width - text_width) / 2 - 10, (image.height - text_height) / 2 - 10),
|
131 |
-
((image.width + text_width) / 2 + 10, (image.height + text_height) / 2 + 10)
|
132 |
-
],
|
133 |
-
fill="black"
|
134 |
-
)
|
135 |
-
draw.text(
|
136 |
-
((image.width - text_width) / 2, (image.height - text_height) / 2),
|
137 |
-
no_detect_text,
|
138 |
-
fill="white",
|
139 |
-
font=font_large
|
140 |
-
)
|
141 |
|
142 |
return image
|
143 |
|
@@ -150,26 +134,30 @@ def predict(image):
|
|
150 |
PIL.Image: Annotated image with bounding boxes and labels.
|
151 |
"""
|
152 |
try:
|
153 |
-
print("Starting prediction...")
|
154 |
# Preprocess the image
|
155 |
input_array = preprocess_image(image)
|
156 |
-
|
157 |
-
|
|
|
|
|
158 |
# Run inference
|
159 |
boxes, classes, scores = sess.run(
|
160 |
[detected_boxes, detected_classes, detected_scores],
|
161 |
feed_dict={input_tensor: input_array}
|
162 |
)
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
|
|
|
|
|
|
|
|
169 |
# Annotate the image with bounding boxes and labels
|
170 |
annotated_image = draw_boxes(image.copy(), boxes, classes, scores, threshold=0.5)
|
171 |
-
print("
|
172 |
-
|
173 |
return annotated_image
|
174 |
|
175 |
except Exception as e:
|
@@ -177,34 +165,57 @@ def predict(image):
|
|
177 |
print(f"Exception during prediction: {e}")
|
178 |
|
179 |
# Return an error image with the error message
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
|
209 |
# Define Gradio interface using the new API
|
210 |
title = "JunkWaxHero 🦸♂️ - Baseball Card Set Identifier"
|
@@ -223,7 +234,7 @@ iface = gr.Interface(
|
|
223 |
title=title,
|
224 |
description=description,
|
225 |
examples=valid_examples,
|
226 |
-
flagging_mode="never" #
|
227 |
)
|
228 |
|
229 |
if __name__ == "__main__":
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import tensorflow as tf
|
3 |
import numpy as np
|
|
|
56 |
- Normalize pixel values
|
57 |
- Convert RGB to BGR if required by the model
|
58 |
"""
|
|
|
59 |
image = image.resize((TARGET_WIDTH, TARGET_HEIGHT))
|
60 |
image_np = np.array(image).astype(np.float32)
|
61 |
image_np = image_np / 255.0 # Normalize to [0,1]
|
62 |
+
if image_np.shape[-1] == 3:
|
63 |
+
# Convert RGB to BGR if required by your model
|
64 |
+
image_np = image_np[..., (2, 1, 0)]
|
65 |
image_np = np.expand_dims(image_np, axis=0) # Add batch dimension
|
66 |
return image_np
|
67 |
|
|
|
82 |
font = ImageFont.truetype("arial.ttf", 15)
|
83 |
except IOError:
|
84 |
font = ImageFont.load_default()
|
|
|
|
|
85 |
|
86 |
+
# If there are no detections at all
|
87 |
+
if boxes.shape[0] == 0 or boxes.shape[1] == 0:
|
88 |
+
# Return the original image without annotation
|
89 |
+
return image
|
90 |
+
|
91 |
+
# Otherwise, proceed to draw bounding boxes
|
92 |
for box, cls, score in zip(boxes[0], classes[0], scores[0]):
|
93 |
if score < threshold:
|
94 |
continue
|
95 |
+
|
96 |
# Convert box coordinates from normalized to absolute
|
97 |
ymin, xmin, ymax, xmax = box
|
98 |
left = xmin * image.width
|
|
|
104 |
draw.rectangle([(left, top), (right, bottom)], outline="red", width=2)
|
105 |
|
106 |
# Prepare label
|
107 |
+
cls_index = int(cls) - 1 # If your classes are 1-indexed
|
108 |
+
if cls_index < 0 or cls_index >= len(labels):
|
109 |
+
label_str = f"cls_{int(cls)}: {score:.2f}"
|
110 |
+
else:
|
111 |
+
label_str = f"{labels[cls_index]}: {score:.2f}"
|
112 |
|
113 |
# Calculate text size using textbbox
|
114 |
+
text_bbox = draw.textbbox((0, 0), label_str, font=font)
|
115 |
text_width = text_bbox[2] - text_bbox[0]
|
116 |
text_height = text_bbox[3] - text_bbox[1]
|
117 |
|
118 |
# Draw label background
|
119 |
+
draw.rectangle([(left, top - text_height - 4),
|
120 |
+
(left + text_width + 4, top)], fill="red")
|
121 |
|
122 |
# Draw text
|
123 |
+
draw.text((left + 2, top - text_height - 2),
|
124 |
+
label_str, fill="white", font=font)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
|
126 |
return image
|
127 |
|
|
|
134 |
PIL.Image: Annotated image with bounding boxes and labels.
|
135 |
"""
|
136 |
try:
|
|
|
137 |
# Preprocess the image
|
138 |
input_array = preprocess_image(image)
|
139 |
+
|
140 |
+
# Debug prints
|
141 |
+
print(f"[DEBUG] Input shape to model: {input_array.shape}")
|
142 |
+
|
143 |
# Run inference
|
144 |
boxes, classes, scores = sess.run(
|
145 |
[detected_boxes, detected_classes, detected_scores],
|
146 |
feed_dict={input_tensor: input_array}
|
147 |
)
|
148 |
+
|
149 |
+
# Debug prints
|
150 |
+
print(f"[DEBUG] boxes shape: {boxes.shape}, classes shape: {classes.shape}, scores shape: {scores.shape}")
|
151 |
+
|
152 |
+
# Check if the output arrays have zero detections
|
153 |
+
# e.g. boxes could have shape (1, 0, 4) if no detections are found
|
154 |
+
if boxes.size == 0 or classes.size == 0 or scores.size == 0:
|
155 |
+
print("[DEBUG] No detections returned by the model.")
|
156 |
+
return _draw_no_detection_message(image)
|
157 |
+
|
158 |
# Annotate the image with bounding boxes and labels
|
159 |
annotated_image = draw_boxes(image.copy(), boxes, classes, scores, threshold=0.5)
|
160 |
+
print("[DEBUG] Annotation completed.")
|
|
|
161 |
return annotated_image
|
162 |
|
163 |
except Exception as e:
|
|
|
165 |
print(f"Exception during prediction: {e}")
|
166 |
|
167 |
# Return an error image with the error message
|
168 |
+
return _draw_error_message()
|
169 |
+
|
170 |
+
def _draw_no_detection_message(image):
|
171 |
+
"""Draws a simple 'No detections found' message on the image."""
|
172 |
+
draw = ImageDraw.Draw(image)
|
173 |
+
try:
|
174 |
+
font = ImageFont.truetype("arial.ttf", 20)
|
175 |
+
except IOError:
|
176 |
+
font = ImageFont.load_default()
|
177 |
+
message = "No detections found."
|
178 |
+
text_bbox = draw.textbbox((0, 0), message, font=font)
|
179 |
+
text_width = text_bbox[2] - text_bbox[0]
|
180 |
+
text_height = text_bbox[3] - text_bbox[1]
|
181 |
+
|
182 |
+
# Center the message
|
183 |
+
x = (image.width - text_width) / 2
|
184 |
+
y = (image.height - text_height) / 2
|
185 |
+
draw.rectangle(
|
186 |
+
[(x - 10, y - 10), (x + text_width + 10, y + text_height + 10)],
|
187 |
+
fill="black"
|
188 |
+
)
|
189 |
+
draw.text((x, y), message, fill="white", font=font)
|
190 |
+
return image
|
191 |
+
|
192 |
+
def _draw_error_message():
|
193 |
+
"""Creates a red image with a centered error message."""
|
194 |
+
error_image = Image.new('RGB', (500, 500), color=(255, 0, 0))
|
195 |
+
draw = ImageDraw.Draw(error_image)
|
196 |
+
try:
|
197 |
+
font = ImageFont.truetype("arial.ttf", 20)
|
198 |
+
except IOError:
|
199 |
+
font = ImageFont.load_default()
|
200 |
+
error_text = "Error during prediction."
|
201 |
+
text_bbox = draw.textbbox((0, 0), error_text, font=font)
|
202 |
+
text_width = text_bbox[2] - text_bbox[0]
|
203 |
+
text_height = text_bbox[3] - text_bbox[1]
|
204 |
+
|
205 |
+
draw.rectangle(
|
206 |
+
[
|
207 |
+
((500 - text_width) / 2 - 10, (500 - text_height) / 2 - 10),
|
208 |
+
((500 + text_width) / 2 + 10, (500 + text_height) / 2 + 10)
|
209 |
+
],
|
210 |
+
fill="black"
|
211 |
+
)
|
212 |
+
draw.text(
|
213 |
+
((500 - text_width) / 2, (500 - text_height) / 2),
|
214 |
+
error_text,
|
215 |
+
fill="white",
|
216 |
+
font=font
|
217 |
+
)
|
218 |
+
return error_image
|
219 |
|
220 |
# Define Gradio interface using the new API
|
221 |
title = "JunkWaxHero 🦸♂️ - Baseball Card Set Identifier"
|
|
|
234 |
title=title,
|
235 |
description=description,
|
236 |
examples=valid_examples,
|
237 |
+
flagging_mode="never" # Use new Gradio parameter
|
238 |
)
|
239 |
|
240 |
if __name__ == "__main__":
|