wuhp committed on
Commit
69a8551
·
verified ·
1 Parent(s): 68db559

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -56
app.py CHANGED
@@ -7,6 +7,7 @@ import shutil
7
  from ultralytics import YOLO
8
  import requests
9
 
 
10
  MODELS_DIR = "models"
11
  MODELS_INFO_FILE = "models_info.json"
12
  TEMP_DIR = "temp"
@@ -15,17 +16,17 @@ OUTPUT_DIR = "outputs"
15
  def download_file(url, dest_path):
16
  """
17
  Download a file from a URL to the destination path.
18
-
19
  Args:
20
  url (str): The URL to download from.
21
  dest_path (str): The local path to save the file.
22
-
23
  Returns:
24
  bool: True if download succeeded, False otherwise.
25
  """
26
  try:
27
  response = requests.get(url, stream=True)
28
- response.raise_for_status()
29
  with open(dest_path, 'wb') as f:
30
  for chunk in response.iter_content(chunk_size=8192):
31
  f.write(chunk)
@@ -39,35 +40,36 @@ def load_models(models_dir=MODELS_DIR, info_file=MODELS_INFO_FILE):
39
  """
40
  Load YOLO models and their information from the specified directory and JSON file.
41
  Downloads models if they are not already present.
42
-
43
  Args:
44
  models_dir (str): Path to the models directory.
45
  info_file (str): Path to the JSON file containing model info.
46
-
47
  Returns:
48
  dict: A dictionary of models and their associated information.
49
  """
50
  with open(info_file, 'r') as f:
51
  models_info = json.load(f)
52
-
53
  models = {}
54
  for model_info in models_info:
55
  model_name = model_info['model_name']
56
  display_name = model_info.get('display_name', model_name)
57
  model_dir = os.path.join(models_dir, model_name)
58
  os.makedirs(model_dir, exist_ok=True)
59
- model_path = os.path.join(model_dir, f"{model_name}.pt")
60
  download_url = model_info['download_url']
61
-
 
62
  if not os.path.isfile(model_path):
63
  print(f"Model '{display_name}' not found locally. Downloading from {download_url}...")
64
  success = download_file(download_url, model_path)
65
  if not success:
66
  print(f"Skipping model '{display_name}' due to download failure.")
67
- continue
68
-
69
  try:
70
-
71
  model = YOLO(model_path)
72
  models[model_name] = {
73
  'display_name': display_name,
@@ -77,16 +79,16 @@ def load_models(models_dir=MODELS_DIR, info_file=MODELS_INFO_FILE):
77
  print(f"Loaded model '{display_name}' from '{model_path}'.")
78
  except Exception as e:
79
  print(f"Error loading model '{display_name}': {e}")
80
-
81
  return models
82
 
83
  def get_model_info(model_info):
84
  """
85
  Retrieve formatted model information for display.
86
-
87
  Args:
88
  model_info (dict): The model's information dictionary.
89
-
90
  Returns:
91
  str: A formatted string containing model details.
92
  """
@@ -94,11 +96,11 @@ def get_model_info(model_info):
94
  class_ids = info.get('class_ids', {})
95
  class_image_counts = info.get('class_image_counts', {})
96
  datasets_used = info.get('datasets_used', [])
97
-
98
  class_ids_formatted = "\n".join([f"{cid}: {cname}" for cid, cname in class_ids.items()])
99
  class_image_counts_formatted = "\n".join([f"{cname}: {count}" for cname, count in class_image_counts.items()])
100
  datasets_used_formatted = "\n".join([f"- {dataset}" for dataset in datasets_used])
101
-
102
  info_text = (
103
  f"**{info.get('display_name', 'Model Name')}**\n\n"
104
  f"**Architecture:** {info.get('architecture', 'N/A')}\n\n"
@@ -118,13 +120,13 @@ def get_model_info(model_info):
118
  def predict_image(model_name, image, confidence, models):
119
  """
120
  Perform prediction on an uploaded image using the selected YOLO model.
121
-
122
  Args:
123
  model_name (str): The name of the selected model.
124
  image (PIL.Image.Image): The uploaded image.
125
  confidence (float): The confidence threshold for detections.
126
  models (dict): The dictionary containing models and their info.
127
-
128
  Returns:
129
  tuple: A status message, the processed image, and the path to the output image.
130
  """
@@ -133,26 +135,32 @@ def predict_image(model_name, image, confidence, models):
133
  if not model:
134
  return "Error: Model not found.", None, None
135
  try:
136
-
137
  os.makedirs(TEMP_DIR, exist_ok=True)
138
  os.makedirs(OUTPUT_DIR, exist_ok=True)
139
-
 
140
  input_image_path = os.path.join(TEMP_DIR, f"{model_name}_input_image.jpg")
141
  image.save(input_image_path)
142
-
 
143
  results = model(input_image_path, save=True, save_txt=False, conf=confidence)
144
-
 
 
145
  latest_run = sorted(Path("runs/detect").glob("predict*"), key=os.path.getmtime)[-1]
146
  output_image_path = os.path.join(latest_run, Path(input_image_path).name)
147
  if not os.path.isfile(output_image_path):
148
-
149
  output_image_path = results[0].save()[0]
150
-
 
151
  final_output_path = os.path.join(OUTPUT_DIR, f"{model_name}_output_image.jpg")
152
  shutil.copy(output_image_path, final_output_path)
153
-
 
154
  output_image = Image.open(final_output_path)
155
-
156
  return "βœ… Prediction completed successfully.", output_image, final_output_path
157
  except Exception as e:
158
  return f"❌ Error during prediction: {str(e)}", None, None
@@ -160,13 +168,13 @@ def predict_image(model_name, image, confidence, models):
160
  def predict_video(model_name, video, confidence, models):
161
  """
162
  Perform prediction on an uploaded video using the selected YOLO model.
163
-
164
  Args:
165
  model_name (str): The name of the selected model.
166
  video (str): Path to the uploaded video file.
167
  confidence (float): The confidence threshold for detections.
168
  models (dict): The dictionary containing models and their info.
169
-
170
  Returns:
171
  tuple: A status message, the processed video, and the path to the output video.
172
  """
@@ -175,35 +183,42 @@ def predict_video(model_name, video, confidence, models):
175
  if not model:
176
  return "Error: Model not found.", None, None
177
  try:
178
-
179
  os.makedirs(TEMP_DIR, exist_ok=True)
180
  os.makedirs(OUTPUT_DIR, exist_ok=True)
181
-
 
182
  input_video_path = os.path.join(TEMP_DIR, f"{model_name}_input_video.mp4")
183
  shutil.copy(video, input_video_path)
184
-
185
- results = model(input_video_path, save=True, save_txt=False, conf=confidence)
186
-
 
 
 
 
187
  latest_run = sorted(Path("runs/detect").glob("predict*"), key=os.path.getmtime)[-1]
188
- output_video_path = os.path.join(latest_run, Path(input_video_path).name)
189
  if not os.path.isfile(output_video_path):
190
-
191
  output_video_path = results[0].save()[0]
192
-
193
- final_output_path = os.path.join(OUTPUT_DIR, f"{model_name}_output_video.mp4")
 
194
  shutil.copy(output_video_path, final_output_path)
195
-
196
  return "βœ… Prediction completed successfully.", final_output_path, final_output_path
197
  except Exception as e:
198
  return f"❌ Error during prediction: {str(e)}", None, None
199
 
200
  def main():
201
-
202
  models = load_models()
203
  if not models:
204
  print("No models loaded. Please check your models_info.json and model URLs.")
205
  return
206
-
 
207
  with gr.Blocks() as demo:
208
  gr.Markdown("# πŸ§ͺ YOLOv11 Model Tester")
209
  gr.Markdown(
@@ -211,7 +226,8 @@ def main():
211
  Upload images or videos to test different YOLOv11 models. Select a model from the dropdown to see its details.
212
  """
213
  )
214
-
 
215
  with gr.Row():
216
  model_dropdown = gr.Dropdown(
217
  choices=[models[m]['display_name'] for m in models],
@@ -219,9 +235,11 @@ def main():
219
  value=None
220
  )
221
  model_info = gr.Markdown("**Model Information will appear here.**")
222
-
 
223
  display_to_name = {models[m]['display_name']: m for m in models}
224
-
 
225
  def update_model_info(selected_display_name):
226
  if not selected_display_name:
227
  return "Please select a model."
@@ -230,13 +248,14 @@ def main():
230
  return "Model information not available."
231
  model_entry = models[model_name]['info']
232
  return get_model_info(model_entry)
233
-
234
  model_dropdown.change(
235
  fn=update_model_info,
236
  inputs=model_dropdown,
237
  outputs=model_info
238
  )
239
-
 
240
  with gr.Row():
241
  confidence_slider = gr.Slider(
242
  minimum=0.0,
@@ -246,33 +265,37 @@ def main():
246
  label="Confidence Threshold",
247
  info="Adjust the minimum confidence required for detections to be displayed."
248
  )
249
-
 
250
  with gr.Tabs():
251
-
252
  with gr.Tab("πŸ–ΌοΈ Image"):
253
  with gr.Column():
254
  image_input = gr.Image(
255
  type='pil',
256
  label="Upload Image for Prediction"
257
-
258
  )
259
  image_predict_btn = gr.Button("πŸ” Predict on Image")
260
  image_status = gr.Markdown("**Status will appear here.**")
261
  image_output = gr.Image(label="Predicted Image")
262
  image_download_btn = gr.File(label="⬇️ Download Predicted Image")
263
-
 
264
  def process_image(selected_display_name, image, confidence):
265
  if not selected_display_name:
266
  return "❌ Please select a model.", None, None
267
  model_name = display_to_name.get(selected_display_name)
268
  return predict_image(model_name, image, confidence, models)
269
-
 
270
  image_predict_btn.click(
271
  fn=process_image,
272
  inputs=[model_dropdown, image_input, confidence_slider],
273
  outputs=[image_status, image_output, image_download_btn]
274
  )
275
-
 
276
  with gr.Tab("πŸŽ₯ Video"):
277
  with gr.Column():
278
  video_input = gr.Video(
@@ -282,27 +305,30 @@ def main():
282
  video_status = gr.Markdown("**Status will appear here.**")
283
  video_output = gr.Video(label="Predicted Video")
284
  video_download_btn = gr.File(label="⬇️ Download Predicted Video")
285
-
 
286
  def process_video(selected_display_name, video, confidence):
287
  if not selected_display_name:
288
  return "❌ Please select a model.", None, None
289
  model_name = display_to_name.get(selected_display_name)
290
  return predict_video(model_name, video, confidence, models)
291
-
 
292
  video_predict_btn.click(
293
  fn=process_video,
294
  inputs=[model_dropdown, video_input, confidence_slider],
295
  outputs=[video_status, video_output, video_download_btn]
296
  )
297
-
298
  gr.Markdown(
299
  """
300
  ---
301
  **Note:** Models are downloaded from GitHub upon first use. Ensure that you have a stable internet connection and sufficient storage space.
302
  """
303
  )
304
-
 
305
  demo.launch()
306
 
307
  if __name__ == "__main__":
308
- main()
 
7
  from ultralytics import YOLO
8
  import requests
9
 
10
+ # Constants
11
  MODELS_DIR = "models"
12
  MODELS_INFO_FILE = "models_info.json"
13
  TEMP_DIR = "temp"
 
16
  def download_file(url, dest_path):
17
  """
18
  Download a file from a URL to the destination path.
19
+
20
  Args:
21
  url (str): The URL to download from.
22
  dest_path (str): The local path to save the file.
23
+
24
  Returns:
25
  bool: True if download succeeded, False otherwise.
26
  """
27
  try:
28
  response = requests.get(url, stream=True)
29
+ response.raise_for_status() # Raise an error on bad status
30
  with open(dest_path, 'wb') as f:
31
  for chunk in response.iter_content(chunk_size=8192):
32
  f.write(chunk)
 
40
  """
41
  Load YOLO models and their information from the specified directory and JSON file.
42
  Downloads models if they are not already present.
43
+
44
  Args:
45
  models_dir (str): Path to the models directory.
46
  info_file (str): Path to the JSON file containing model info.
47
+
48
  Returns:
49
  dict: A dictionary of models and their associated information.
50
  """
51
  with open(info_file, 'r') as f:
52
  models_info = json.load(f)
53
+
54
  models = {}
55
  for model_info in models_info:
56
  model_name = model_info['model_name']
57
  display_name = model_info.get('display_name', model_name)
58
  model_dir = os.path.join(models_dir, model_name)
59
  os.makedirs(model_dir, exist_ok=True)
60
+ model_path = os.path.join(model_dir, f"{model_name}.pt") # e.g., models/human/human.pt
61
  download_url = model_info['download_url']
62
+
63
+ # Check if the model file exists
64
  if not os.path.isfile(model_path):
65
  print(f"Model '{display_name}' not found locally. Downloading from {download_url}...")
66
  success = download_file(download_url, model_path)
67
  if not success:
68
  print(f"Skipping model '{display_name}' due to download failure.")
69
+ continue # Skip loading this model
70
+
71
  try:
72
+ # Load the YOLO model
73
  model = YOLO(model_path)
74
  models[model_name] = {
75
  'display_name': display_name,
 
79
  print(f"Loaded model '{display_name}' from '{model_path}'.")
80
  except Exception as e:
81
  print(f"Error loading model '{display_name}': {e}")
82
+
83
  return models
84
 
85
  def get_model_info(model_info):
86
  """
87
  Retrieve formatted model information for display.
88
+
89
  Args:
90
  model_info (dict): The model's information dictionary.
91
+
92
  Returns:
93
  str: A formatted string containing model details.
94
  """
 
96
  class_ids = info.get('class_ids', {})
97
  class_image_counts = info.get('class_image_counts', {})
98
  datasets_used = info.get('datasets_used', [])
99
+
100
  class_ids_formatted = "\n".join([f"{cid}: {cname}" for cid, cname in class_ids.items()])
101
  class_image_counts_formatted = "\n".join([f"{cname}: {count}" for cname, count in class_image_counts.items()])
102
  datasets_used_formatted = "\n".join([f"- {dataset}" for dataset in datasets_used])
103
+
104
  info_text = (
105
  f"**{info.get('display_name', 'Model Name')}**\n\n"
106
  f"**Architecture:** {info.get('architecture', 'N/A')}\n\n"
 
120
  def predict_image(model_name, image, confidence, models):
121
  """
122
  Perform prediction on an uploaded image using the selected YOLO model.
123
+
124
  Args:
125
  model_name (str): The name of the selected model.
126
  image (PIL.Image.Image): The uploaded image.
127
  confidence (float): The confidence threshold for detections.
128
  models (dict): The dictionary containing models and their info.
129
+
130
  Returns:
131
  tuple: A status message, the processed image, and the path to the output image.
132
  """
 
135
  if not model:
136
  return "Error: Model not found.", None, None
137
  try:
138
+ # Ensure temporary and output directories exist
139
  os.makedirs(TEMP_DIR, exist_ok=True)
140
  os.makedirs(OUTPUT_DIR, exist_ok=True)
141
+
142
+ # Save the uploaded image to a temporary path
143
  input_image_path = os.path.join(TEMP_DIR, f"{model_name}_input_image.jpg")
144
  image.save(input_image_path)
145
+
146
+ # Perform prediction with user-specified confidence
147
  results = model(input_image_path, save=True, save_txt=False, conf=confidence)
148
+
149
+ # Determine the output path
150
+ # Ultralytics YOLO saves the results in 'runs/detect/predict' by default
151
  latest_run = sorted(Path("runs/detect").glob("predict*"), key=os.path.getmtime)[-1]
152
  output_image_path = os.path.join(latest_run, Path(input_image_path).name)
153
  if not os.path.isfile(output_image_path):
154
+ # Alternative method to get the output path
155
  output_image_path = results[0].save()[0]
156
+
157
+ # Copy the output image to OUTPUT_DIR
158
  final_output_path = os.path.join(OUTPUT_DIR, f"{model_name}_output_image.jpg")
159
  shutil.copy(output_image_path, final_output_path)
160
+
161
+ # Open the output image
162
  output_image = Image.open(final_output_path)
163
+
164
  return "βœ… Prediction completed successfully.", output_image, final_output_path
165
  except Exception as e:
166
  return f"❌ Error during prediction: {str(e)}", None, None
 
168
  def predict_video(model_name, video, confidence, models):
169
  """
170
  Perform prediction on an uploaded video using the selected YOLO model.
171
+
172
  Args:
173
  model_name (str): The name of the selected model.
174
  video (str): Path to the uploaded video file.
175
  confidence (float): The confidence threshold for detections.
176
  models (dict): The dictionary containing models and their info.
177
+
178
  Returns:
179
  tuple: A status message, the processed video, and the path to the output video.
180
  """
 
183
  if not model:
184
  return "Error: Model not found.", None, None
185
  try:
186
+ # Ensure temporary and output directories exist
187
  os.makedirs(TEMP_DIR, exist_ok=True)
188
  os.makedirs(OUTPUT_DIR, exist_ok=True)
189
+
190
+ # Save the uploaded video to a temporary path
191
  input_video_path = os.path.join(TEMP_DIR, f"{model_name}_input_video.mp4")
192
  shutil.copy(video, input_video_path)
193
+
194
+ # Perform prediction with user-specified confidence and specify output format
195
+ # Here, we set save_format to 'avi' to ensure compatibility
196
+ results = model(input_video_path, save=True, save_txt=False, conf=confidence, save_format='avi')
197
+
198
+ # Determine the output path
199
+ # Ultralytics YOLO saves the results in 'runs/detect/predict' by default
200
  latest_run = sorted(Path("runs/detect").glob("predict*"), key=os.path.getmtime)[-1]
201
+ output_video_path = os.path.join(latest_run, f"{model_name}_input_video.avi")
202
  if not os.path.isfile(output_video_path):
203
+ # Alternative method to get the output path
204
  output_video_path = results[0].save()[0]
205
+
206
+ # Copy the output video to OUTPUT_DIR
207
+ final_output_path = os.path.join(OUTPUT_DIR, f"{model_name}_output_video.avi")
208
  shutil.copy(output_video_path, final_output_path)
209
+
210
  return "βœ… Prediction completed successfully.", final_output_path, final_output_path
211
  except Exception as e:
212
  return f"❌ Error during prediction: {str(e)}", None, None
213
 
214
  def main():
215
+ # Load the models and their information
216
  models = load_models()
217
  if not models:
218
  print("No models loaded. Please check your models_info.json and model URLs.")
219
  return
220
+
221
+ # Initialize Gradio Blocks interface
222
  with gr.Blocks() as demo:
223
  gr.Markdown("# πŸ§ͺ YOLOv11 Model Tester")
224
  gr.Markdown(
 
226
  Upload images or videos to test different YOLOv11 models. Select a model from the dropdown to see its details.
227
  """
228
  )
229
+
230
+ # Model selection and info
231
  with gr.Row():
232
  model_dropdown = gr.Dropdown(
233
  choices=[models[m]['display_name'] for m in models],
 
235
  value=None
236
  )
237
  model_info = gr.Markdown("**Model Information will appear here.**")
238
+
239
+ # Mapping from display_name to model_name
240
  display_to_name = {models[m]['display_name']: m for m in models}
241
+
242
+ # Update model_info when a model is selected
243
  def update_model_info(selected_display_name):
244
  if not selected_display_name:
245
  return "Please select a model."
 
248
  return "Model information not available."
249
  model_entry = models[model_name]['info']
250
  return get_model_info(model_entry)
251
+
252
  model_dropdown.change(
253
  fn=update_model_info,
254
  inputs=model_dropdown,
255
  outputs=model_info
256
  )
257
+
258
+ # Confidence Threshold Slider
259
  with gr.Row():
260
  confidence_slider = gr.Slider(
261
  minimum=0.0,
 
265
  label="Confidence Threshold",
266
  info="Adjust the minimum confidence required for detections to be displayed."
267
  )
268
+
269
+ # Tabs for different input types
270
  with gr.Tabs():
271
+ # Image Prediction Tab
272
  with gr.Tab("πŸ–ΌοΈ Image"):
273
  with gr.Column():
274
  image_input = gr.Image(
275
  type='pil',
276
  label="Upload Image for Prediction"
277
+ # Removed 'tool' parameter
278
  )
279
  image_predict_btn = gr.Button("πŸ” Predict on Image")
280
  image_status = gr.Markdown("**Status will appear here.**")
281
  image_output = gr.Image(label="Predicted Image")
282
  image_download_btn = gr.File(label="⬇️ Download Predicted Image")
283
+
284
+ # Define the image prediction function
285
  def process_image(selected_display_name, image, confidence):
286
  if not selected_display_name:
287
  return "❌ Please select a model.", None, None
288
  model_name = display_to_name.get(selected_display_name)
289
  return predict_image(model_name, image, confidence, models)
290
+
291
+ # Connect the predict button
292
  image_predict_btn.click(
293
  fn=process_image,
294
  inputs=[model_dropdown, image_input, confidence_slider],
295
  outputs=[image_status, image_output, image_download_btn]
296
  )
297
+
298
+ # Video Prediction Tab
299
  with gr.Tab("πŸŽ₯ Video"):
300
  with gr.Column():
301
  video_input = gr.Video(
 
305
  video_status = gr.Markdown("**Status will appear here.**")
306
  video_output = gr.Video(label="Predicted Video")
307
  video_download_btn = gr.File(label="⬇️ Download Predicted Video")
308
+
309
+ # Define the video prediction function
310
  def process_video(selected_display_name, video, confidence):
311
  if not selected_display_name:
312
  return "❌ Please select a model.", None, None
313
  model_name = display_to_name.get(selected_display_name)
314
  return predict_video(model_name, video, confidence, models)
315
+
316
+ # Connect the predict button
317
  video_predict_btn.click(
318
  fn=process_video,
319
  inputs=[model_dropdown, video_input, confidence_slider],
320
  outputs=[video_status, video_output, video_download_btn]
321
  )
322
+
323
  gr.Markdown(
324
  """
325
  ---
326
  **Note:** Models are downloaded from GitHub upon first use. Ensure that you have a stable internet connection and sufficient storage space.
327
  """
328
  )
329
+
330
+ # Launch the Gradio app
331
  demo.launch()
332
 
333
  if __name__ == "__main__":
334
+ main()