Eric P. Nusbaum commited on
Commit
2d23a8b
·
1 Parent(s): 987957a

Updated app.py

Browse files
Files changed (2) hide show
  1. app.py +41 -15
  2. requirements.txt +2 -2
app.py CHANGED
@@ -9,6 +9,9 @@ import os
9
  # Suppress TensorFlow logging for cleaner logs
10
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
11
 
 
 
 
12
  # Load labels
13
  with open('tensorflow/labels.txt', 'r') as f:
14
  labels = f.read().splitlines()
@@ -36,9 +39,8 @@ detected_classes = graph.get_tensor_by_name('detected_classes:0')
36
  detected_scores = graph.get_tensor_by_name('detected_scores:0')
37
 
38
  # Define the target size based on your model's expected input
39
- # You may need to adjust these values based on your model's requirements
40
- TARGET_WIDTH = 224
41
- TARGET_HEIGHT = 224
42
 
43
  def preprocess_image(image):
44
  """
@@ -89,12 +91,16 @@ def draw_boxes(image, boxes, classes, scores, threshold=0.5):
89
  # Prepare label
90
  label = f"{labels[int(cls) - 1]}: {score:.2f}"
91
 
 
 
 
 
 
92
  # Draw label background
93
- text_size = draw.textsize(label, font=font)
94
- draw.rectangle([(left, top - text_size[1]), (left + text_size[0], top)], fill="red")
95
 
96
  # Draw text
97
- draw.text((left, top - text_size[1]), label, fill="white", font=font)
98
 
99
  return image
100
 
@@ -122,7 +128,7 @@ def predict(image):
122
  return annotated_image
123
 
124
  except Exception as e:
125
- # Return a red image indicating an error
126
  error_image = Image.new('RGB', (500, 500), color=(255, 0, 0))
127
  draw = ImageDraw.Draw(error_image)
128
  try:
@@ -130,26 +136,46 @@ def predict(image):
130
  except IOError:
131
  font = ImageFont.load_default()
132
  error_text = "Error during prediction."
133
- text_size = draw.textsize(error_text, font=font)
134
- draw.text(((500 - text_size[0]) / 2, (500 - text_size[1]) / 2), error_text, fill="white", font=font)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  return error_image
136
 
137
  # Define Gradio interface using the new API
138
  title = "JunkWaxHero 🦸‍♂️ - Baseball Card Set Identifier"
139
  description = "Upload an image of a baseball card, and JunkWaxHero will identify the set it belongs to with high accuracy."
140
 
 
 
 
 
 
 
141
  iface = gr.Interface(
142
  fn=predict,
143
  inputs=gr.Image(type="pil"),
144
  outputs=gr.Image(type="pil"),
145
  title=title,
146
  description=description,
147
- examples=[
148
- ["examples/card1.jpg"],
149
- ["examples/card2.jpg"],
150
- ["examples/card3.jpg"]
151
- ],
152
- allow_flagging="never"
153
  )
154
 
155
  if __name__ == "__main__":
 
9
  # Suppress TensorFlow logging for cleaner logs
10
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
11
 
12
+ # Disable GPU usage explicitly to prevent TensorFlow from attempting to access GPU libraries
13
+ os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
14
+
15
  # Load labels
16
  with open('tensorflow/labels.txt', 'r') as f:
17
  labels = f.read().splitlines()
 
39
  detected_scores = graph.get_tensor_by_name('detected_scores:0')
40
 
41
  # Define the target size based on your model's expected input
42
+ TARGET_WIDTH = 320
43
+ TARGET_HEIGHT = 320
 
44
 
45
  def preprocess_image(image):
46
  """
 
91
  # Prepare label
92
  label = f"{labels[int(cls) - 1]}: {score:.2f}"
93
 
94
+ # Calculate text size using textbbox
95
+ text_bbox = draw.textbbox((0, 0), label, font=font)
96
+ text_width = text_bbox[2] - text_bbox[0]
97
+ text_height = text_bbox[3] - text_bbox[1]
98
+
99
  # Draw label background
100
+ draw.rectangle([(left, top - text_height - 4), (left + text_width + 4, top)], fill="red")
 
101
 
102
  # Draw text
103
+ draw.text((left + 2, top - text_height - 2), label, fill="white", font=font)
104
 
105
  return image
106
 
 
128
  return annotated_image
129
 
130
  except Exception as e:
131
+ # Return an error image with the error message
132
  error_image = Image.new('RGB', (500, 500), color=(255, 0, 0))
133
  draw = ImageDraw.Draw(error_image)
134
  try:
 
136
  except IOError:
137
  font = ImageFont.load_default()
138
  error_text = "Error during prediction."
139
+
140
+ # Calculate text size using textbbox
141
+ text_bbox = draw.textbbox((0, 0), error_text, font=font)
142
+ text_width = text_bbox[2] - text_bbox[0]
143
+ text_height = text_bbox[3] - text_bbox[1]
144
+
145
+ # Center the text
146
+ draw.rectangle(
147
+ [
148
+ ((500 - text_width) / 2 - 10, (500 - text_height) / 2 - 10),
149
+ ((500 + text_width) / 2 + 10, (500 + text_height) / 2 + 10)
150
+ ],
151
+ fill="black"
152
+ )
153
+ draw.text(
154
+ ((500 - text_width) / 2, (500 - text_height) / 2),
155
+ error_text,
156
+ fill="white",
157
+ font=font
158
+ )
159
  return error_image
160
 
161
  # Define Gradio interface using the new API
162
  title = "JunkWaxHero 🦸‍♂️ - Baseball Card Set Identifier"
163
  description = "Upload an image of a baseball card, and JunkWaxHero will identify the set it belongs to with high accuracy."
164
 
165
+ # Verify that example images exist to prevent FileNotFoundError
166
+ example_images = ["examples/card1.jpg", "examples/card2.jpg", "examples/card3.jpg"]
167
+ valid_examples = [img for img in example_images if os.path.exists(img)]
168
+ if not valid_examples:
169
+ valid_examples = None # Remove examples if none exist
170
+
171
  iface = gr.Interface(
172
  fn=predict,
173
  inputs=gr.Image(type="pil"),
174
  outputs=gr.Image(type="pil"),
175
  title=title,
176
  description=description,
177
+ examples=valid_examples,
178
+ flagging_mode="never" # Updated parameter
 
 
 
 
179
  )
180
 
181
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- gradio
2
  tensorflow==2.11.0 # Ensure this matches the TensorFlow version used during model training
3
  Pillow
4
- numpy<2.0.0
 
1
+ gradio>=3.0.0
2
  tensorflow==2.11.0 # Ensure this matches the TensorFlow version used during model training
3
  Pillow
4
+ numpy<2.0.0