wuhp commited on
Commit
e8e74ae
Β·
verified Β·
1 Parent(s): f855bd4

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +180 -76
main.py CHANGED
@@ -5,15 +5,46 @@ from pathlib import Path
5
  from PIL import Image
6
  import shutil
7
  from ultralytics import YOLO
 
8
 
9
- def load_models(models_dir='models', info_file='models_info.json'):
 
 
 
 
 
 
10
  """
11
- Load YOLO models and their information from the specified directory and JSON file.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
 
 
 
 
 
13
  Args:
14
  models_dir (str): Path to the models directory.
15
  info_file (str): Path to the JSON file containing model info.
16
-
17
  Returns:
18
  dict: A dictionary of models and their associated information.
19
  """
@@ -23,138 +54,205 @@ def load_models(models_dir='models', info_file='models_info.json'):
23
  models = {}
24
  for model_info in models_info:
25
  model_name = model_info['model_name']
26
- model_path = os.path.join(models_dir, model_name, 'best.pt') # Assuming 'best.pt' as the weight file
27
- if os.path.isfile(model_path):
28
- try:
29
- # Load the YOLO model
30
- model = YOLO(model_path)
31
- models[model_name] = {
32
- 'model': model,
33
- 'mAP': model_info.get('mAP_score', 'N/A'),
34
- 'num_images': model_info.get('num_images', 'N/A')
35
- }
36
- print(f"Loaded model '{model_name}' from '{model_path}'.")
37
- except Exception as e:
38
- print(f"Error loading model '{model_name}': {e}")
39
- else:
40
- print(f"Model weight file for '{model_name}' not found at '{model_path}'. Skipping.")
 
 
 
 
 
 
 
 
 
 
 
41
  return models
42
 
43
- def get_model_info(model_name, models):
44
  """
45
- Retrieve model information for the selected model.
46
-
47
  Args:
48
- model_name (str): The name of the model.
49
- models (dict): The dictionary containing models and their info.
50
-
51
  Returns:
52
- str: A formatted string containing model information.
53
  """
54
- model_info = models.get(model_name, {})
55
- if not model_info:
56
- return "Model information not available."
 
 
 
 
 
 
57
  info_text = (
58
- f"**Model Name:** {model_name}\n\n"
59
- f"**mAP Score:** {model_info.get('mAP', 'N/A')}\n\n"
60
- f"**Number of Images Trained On:** {model_info.get('num_images', 'N/A')}"
 
 
 
 
 
 
 
 
 
61
  )
62
  return info_text
63
 
64
  def predict_image(model_name, image, models):
65
  """
66
  Perform prediction on an uploaded image using the selected YOLO model.
67
-
68
  Args:
69
  model_name (str): The name of the selected model.
70
  image (PIL.Image.Image): The uploaded image.
71
  models (dict): The dictionary containing models and their info.
72
-
73
  Returns:
74
  tuple: A status message, the processed image, and the path to the output image.
75
  """
76
- model = models.get(model_name, {}).get('model', None)
 
77
  if not model:
78
  return "Error: Model not found.", None, None
79
  try:
 
 
 
 
80
  # Save the uploaded image to a temporary path
81
- input_image_path = f"temp/{model_name}_input_image.jpg"
82
- os.makedirs(os.path.dirname(input_image_path), exist_ok=True)
83
  image.save(input_image_path)
84
-
85
  # Perform prediction
86
  results = model(input_image_path, save=True, save_txt=False, conf=0.25)
87
- # Ultralytics saves the result images in 'runs/detect/predict'
88
- output_image_path = results[0].save()[0] # Get the path to the saved image
89
-
 
 
 
 
 
 
 
 
 
 
 
90
  # Open the output image
91
- output_image = Image.open(output_image_path)
92
-
93
- return "Prediction completed successfully.", output_image, output_image_path
94
  except Exception as e:
95
- return f"Error during prediction: {str(e)}", None, None
96
 
97
  def predict_video(model_name, video, models):
98
  """
99
  Perform prediction on an uploaded video using the selected YOLO model.
100
-
101
  Args:
102
  model_name (str): The name of the selected model.
103
  video (str): Path to the uploaded video file.
104
  models (dict): The dictionary containing models and their info.
105
-
106
  Returns:
107
  tuple: A status message, the processed video, and the path to the output video.
108
  """
109
- model = models.get(model_name, {}).get('model', None)
 
110
  if not model:
111
  return "Error: Model not found.", None, None
112
  try:
113
- # Ensure the video is saved in a temporary location
114
- input_video_path = video.name
115
- if not os.path.isfile(input_video_path):
116
- # If the video is a temp file provided by Gradio
117
- shutil.copy(video.name, input_video_path)
 
 
118
 
119
  # Perform prediction
120
  results = model(input_video_path, save=True, save_txt=False, conf=0.25)
121
- # Ultralytics saves the result videos in 'runs/detect/predict'
122
- output_video_path = results[0].save()[0] # Get the path to the saved video
123
-
124
- return "Prediction completed successfully.", output_video_path, output_video_path
 
 
 
 
 
 
 
 
 
 
125
  except Exception as e:
126
- return f"Error during prediction: {str(e)}", None, None
127
 
128
  def main():
129
  # Load the models and their information
130
  models = load_models()
131
-
 
 
 
132
  # Initialize Gradio Blocks interface
133
  with gr.Blocks() as demo:
134
- gr.Markdown("# πŸ§ͺ YOLO Model Tester")
135
-
136
  gr.Markdown(
137
  """
138
- Upload images or videos to test different YOLO models. Select a model from the dropdown to see its details.
139
  """
140
  )
141
-
142
  # Model selection and info
143
  with gr.Row():
144
  model_dropdown = gr.Dropdown(
145
- choices=list(models.keys()),
146
  label="Select Model",
147
  value=None
148
  )
149
  model_info = gr.Markdown("**Model Information will appear here.**")
150
-
 
 
 
151
  # Update model_info when a model is selected
 
 
 
 
 
 
 
 
 
152
  model_dropdown.change(
153
- fn=lambda model_name: get_model_info(model_name, models) if model_name else "Please select a model.",
154
  inputs=model_dropdown,
155
  outputs=model_info
156
  )
157
-
158
  # Tabs for different input types
159
  with gr.Tabs():
160
  # Image Prediction Tab
@@ -169,18 +267,21 @@ def main():
169
  image_status = gr.Markdown("**Status will appear here.**")
170
  image_output = gr.Image(label="Predicted Image")
171
  image_download_btn = gr.File(label="⬇️ Download Predicted Image")
172
-
173
  # Define the image prediction function
174
- def process_image(model_name, image):
 
 
 
175
  return predict_image(model_name, image, models)
176
-
177
  # Connect the predict button
178
  image_predict_btn.click(
179
  fn=process_image,
180
  inputs=[model_dropdown, image_input],
181
  outputs=[image_status, image_output, image_download_btn]
182
  )
183
-
184
  # Video Prediction Tab
185
  with gr.Tab("πŸŽ₯ Video"):
186
  with gr.Column():
@@ -191,25 +292,28 @@ def main():
191
  video_status = gr.Markdown("**Status will appear here.**")
192
  video_output = gr.Video(label="Predicted Video")
193
  video_download_btn = gr.File(label="⬇️ Download Predicted Video")
194
-
195
  # Define the video prediction function
196
- def process_video(model_name, video):
 
 
 
197
  return predict_video(model_name, video, models)
198
-
199
  # Connect the predict button
200
  video_predict_btn.click(
201
  fn=process_video,
202
  inputs=[model_dropdown, video_input],
203
  outputs=[video_status, video_output, video_download_btn]
204
  )
205
-
206
  gr.Markdown(
207
  """
208
  ---
209
- **Note:** Ensure that the YOLO models are correctly placed in the `models/` directory and that `models_info.json` is properly configured.
210
  """
211
  )
212
-
213
  # Launch the Gradio app
214
  demo.launch()
215
 
 
5
  from PIL import Image
6
  import shutil
7
  from ultralytics import YOLO
8
+ import requests
9
 
10
+ # Constants
11
+ MODELS_DIR = "models"
12
+ MODELS_INFO_FILE = "models_info.json"
13
+ TEMP_DIR = "temp"
14
+ OUTPUT_DIR = "outputs"
15
+
16
def download_file(url, dest_path):
    """
    Download a file from a URL to the destination path.

    Args:
        url (str): The URL to download from.
        dest_path (str): The local path to save the file.

    Returns:
        bool: True if download succeeded, False otherwise.
    """
    try:
        # Stream the body so large model weights are never held fully in
        # memory, bound the connect/read wait so a dead host cannot hang
        # startup, and close the connection via the context manager.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()  # Raise an error on bad status
            with open(dest_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # Filter out keep-alive chunks
                        f.write(chunk)
        print(f"Downloaded {url} to {dest_path}.")
        return True
    except Exception as e:
        # Best-effort: report the failure; the caller skips this model.
        print(f"Failed to download {url}. Error: {e}")
        return False
39
+ def load_models(models_dir=MODELS_DIR, info_file=MODELS_INFO_FILE):
40
+ """
41
+ Load YOLO models and their information from the specified directory and JSON file.
42
+ Downloads models if they are not already present.
43
+
44
  Args:
45
  models_dir (str): Path to the models directory.
46
  info_file (str): Path to the JSON file containing model info.
47
+
48
  Returns:
49
  dict: A dictionary of models and their associated information.
50
  """
 
54
  models = {}
55
  for model_info in models_info:
56
  model_name = model_info['model_name']
57
+ display_name = model_info.get('display_name', model_name)
58
+ model_dir = os.path.join(models_dir, model_name)
59
+ os.makedirs(model_dir, exist_ok=True)
60
+ model_path = os.path.join(model_dir, f"{model_name}.pt") # e.g., models/human/human.pt
61
+ download_url = model_info['download_url']
62
+
63
+ # Check if the model file exists
64
+ if not os.path.isfile(model_path):
65
+ print(f"Model '{display_name}' not found locally. Downloading from {download_url}...")
66
+ success = download_file(download_url, model_path)
67
+ if not success:
68
+ print(f"Skipping model '{display_name}' due to download failure.")
69
+ continue # Skip loading this model
70
+
71
+ try:
72
+ # Load the YOLO model
73
+ model = YOLO(model_path)
74
+ models[model_name] = {
75
+ 'display_name': display_name,
76
+ 'model': model,
77
+ 'info': model_info
78
+ }
79
+ print(f"Loaded model '{display_name}' from '{model_path}'.")
80
+ except Exception as e:
81
+ print(f"Error loading model '{display_name}': {e}")
82
+
83
  return models
84
 
85
def get_model_info(model_info):
    """
    Build a Markdown-formatted summary of a model's metadata for display.

    Args:
        model_info (dict): The model's information dictionary (one entry of
            models_info.json). All fields are optional; missing scalar fields
            render as 'N/A' and missing collections render empty.

    Returns:
        str: A formatted Markdown string containing model details.
    """
    class_ids = model_info.get('class_ids', {})
    class_image_counts = model_info.get('class_image_counts', {})
    datasets_used = model_info.get('datasets_used', [])

    class_ids_formatted = "\n".join(f"{cid}: {cname}" for cid, cname in class_ids.items())
    class_image_counts_formatted = "\n".join(
        f"{cname}: {count}" for cname, count in class_image_counts.items()
    )
    datasets_used_formatted = "\n".join(f"- {dataset}" for dataset in datasets_used)

    # NOTE: the metric label was scraped as "[email protected]" — a Cloudflare
    # email-obfuscation artifact of the intended "mAP@0.5" (value: mAP_score).
    info_text = (
        f"**{model_info.get('display_name', 'Model Name')}**\n\n"
        f"**Architecture:** {model_info.get('architecture', 'N/A')}\n\n"
        f"**Training Epochs:** {model_info.get('training_epochs', 'N/A')}\n\n"
        f"**Batch Size:** {model_info.get('batch_size', 'N/A')}\n\n"
        f"**Optimizer:** {model_info.get('optimizer', 'N/A')}\n\n"
        f"**Learning Rate:** {model_info.get('learning_rate', 'N/A')}\n\n"
        f"**Data Augmentation Level:** {model_info.get('data_augmentation_level', 'N/A')}\n\n"
        f"**mAP@0.5:** {model_info.get('mAP_score', 'N/A')}\n\n"
        f"**Number of Images Trained On:** {model_info.get('num_images', 'N/A')}\n\n"
        f"**Class IDs:**\n{class_ids_formatted}\n\n"
        f"**Datasets Used:**\n{datasets_used_formatted}\n\n"
        f"**Class Image Counts:**\n{class_image_counts_formatted}"
    )
    return info_text
119
 
120
def predict_image(model_name, image, models):
    """
    Perform prediction on an uploaded image using the selected YOLO model.

    Args:
        model_name (str): The name of the selected model (key into models).
        image (PIL.Image.Image): The uploaded image.
        models (dict): Mapping of model name -> {'model': YOLO, ...}.

    Returns:
        tuple: A status message, the processed image (or None), and the path
            to the output image (or None).
    """
    model = models.get(model_name, {}).get('model')
    if not model:
        return "Error: Model not found.", None, None
    try:
        # Ensure temporary and output directories exist
        os.makedirs(TEMP_DIR, exist_ok=True)
        os.makedirs(OUTPUT_DIR, exist_ok=True)

        # Save the uploaded image to a temporary path
        input_image_path = os.path.join(TEMP_DIR, f"{model_name}_input_image.jpg")
        image.save(input_image_path)

        # Perform prediction; with save=True Ultralytics writes the annotated
        # image under a runs/detect/predict* directory.
        results = model(input_image_path, save=True, save_txt=False, conf=0.25)

        # Locate the annotated image. Prefer the run directory reported by the
        # result object itself over mtime-sorting runs/detect, which raised
        # IndexError on a missing directory and is racy under concurrent runs.
        save_dir = getattr(results[0], 'save_dir', None)
        output_image_path = (
            os.path.join(str(save_dir), Path(input_image_path).name) if save_dir else ""
        )
        if not os.path.isfile(output_image_path):
            # Fallback: Results.save() returns the path it wrote (a str).
            # The original code indexed it with [0], yielding one character.
            output_image_path = results[0].save()

        # Copy the output image to OUTPUT_DIR under a stable, per-model name
        final_output_path = os.path.join(OUTPUT_DIR, f"{model_name}_output_image.jpg")
        shutil.copy(output_image_path, final_output_path)

        # Open the output image for display in the UI
        output_image = Image.open(final_output_path)

        return "βœ… Prediction completed successfully.", output_image, final_output_path
    except Exception as e:
        return f"❌ Error during prediction: {str(e)}", None, None
 
168
def predict_video(model_name, video, models):
    """
    Perform prediction on an uploaded video using the selected YOLO model.

    Args:
        model_name (str): The name of the selected model (key into models).
        video (str): Path to the uploaded video file.
        models (dict): Mapping of model name -> {'model': YOLO, ...}.

    Returns:
        tuple: A status message, the processed video path (or None), and the
            path to the output video (or None).
    """
    model = models.get(model_name, {}).get('model')
    if not model:
        return "Error: Model not found.", None, None
    try:
        # Ensure temporary and output directories exist
        os.makedirs(TEMP_DIR, exist_ok=True)
        os.makedirs(OUTPUT_DIR, exist_ok=True)

        # Save the uploaded video to a temporary path
        input_video_path = os.path.join(TEMP_DIR, f"{model_name}_input_video.mp4")
        shutil.copy(video, input_video_path)

        # Perform prediction; with save=True Ultralytics writes the annotated
        # video under a runs/detect/predict* directory.
        results = model(input_video_path, save=True, save_txt=False, conf=0.25)

        # Locate the annotated video. Prefer the run directory reported by the
        # result object itself over mtime-sorting runs/detect, which raised
        # IndexError on a missing directory and is racy under concurrent runs.
        # (The original fallback `results[0].save()[0]` indexed the returned
        # path string, yielding a single character.)
        save_dir = getattr(results[0], 'save_dir', None)
        output_video_path = ""
        if save_dir:
            exact = os.path.join(str(save_dir), Path(input_video_path).name)
            if os.path.isfile(exact):
                output_video_path = exact
            else:
                # NOTE(review): Ultralytics may re-encode to a different
                # container (e.g. .avi); fall back to any output sharing the
                # input's stem — confirm against the installed version.
                stem_matches = sorted(Path(save_dir).glob(Path(input_video_path).stem + ".*"))
                if stem_matches:
                    output_video_path = str(stem_matches[0])
        if not output_video_path or not os.path.isfile(output_video_path):
            return "❌ Error: could not locate the predicted video output.", None, None

        # Copy the output video to OUTPUT_DIR under a stable, per-model name
        final_output_path = os.path.join(OUTPUT_DIR, f"{model_name}_output_video.mp4")
        shutil.copy(output_video_path, final_output_path)

        return "βœ… Prediction completed successfully.", final_output_path, final_output_path
    except Exception as e:
        return f"❌ Error during prediction: {str(e)}", None, None
 
212
  def main():
213
  # Load the models and their information
214
  models = load_models()
215
+ if not models:
216
+ print("No models loaded. Please check your models_info.json and model URLs.")
217
+ return
218
+
219
  # Initialize Gradio Blocks interface
220
  with gr.Blocks() as demo:
221
+ gr.Markdown("# πŸ§ͺ YOLOv11 Model Tester")
 
222
  gr.Markdown(
223
  """
224
+ Upload images or videos to test different YOLOv11 models. Select a model from the dropdown to see its details.
225
  """
226
  )
227
+
228
  # Model selection and info
229
  with gr.Row():
230
  model_dropdown = gr.Dropdown(
231
+ choices=[models[m]['display_name'] for m in models],
232
  label="Select Model",
233
  value=None
234
  )
235
  model_info = gr.Markdown("**Model Information will appear here.**")
236
+
237
+ # Mapping from display_name to model_name
238
+ display_to_name = {models[m]['display_name']: m for m in models}
239
+
240
  # Update model_info when a model is selected
241
+ def update_model_info(selected_display_name):
242
+ if not selected_display_name:
243
+ return "Please select a model."
244
+ model_name = display_to_name.get(selected_display_name)
245
+ if not model_name:
246
+ return "Model information not available."
247
+ model_entry = models[model_name]['info']
248
+ return get_model_info(model_entry)
249
+
250
  model_dropdown.change(
251
+ fn=update_model_info,
252
  inputs=model_dropdown,
253
  outputs=model_info
254
  )
255
+
256
  # Tabs for different input types
257
  with gr.Tabs():
258
  # Image Prediction Tab
 
267
  image_status = gr.Markdown("**Status will appear here.**")
268
  image_output = gr.Image(label="Predicted Image")
269
  image_download_btn = gr.File(label="⬇️ Download Predicted Image")
270
+
271
  # Define the image prediction function
272
+ def process_image(selected_display_name, image):
273
+ if not selected_display_name:
274
+ return "❌ Please select a model.", None, None
275
+ model_name = display_to_name.get(selected_display_name)
276
  return predict_image(model_name, image, models)
277
+
278
  # Connect the predict button
279
  image_predict_btn.click(
280
  fn=process_image,
281
  inputs=[model_dropdown, image_input],
282
  outputs=[image_status, image_output, image_download_btn]
283
  )
284
+
285
  # Video Prediction Tab
286
  with gr.Tab("πŸŽ₯ Video"):
287
  with gr.Column():
 
292
  video_status = gr.Markdown("**Status will appear here.**")
293
  video_output = gr.Video(label="Predicted Video")
294
  video_download_btn = gr.File(label="⬇️ Download Predicted Video")
295
+
296
  # Define the video prediction function
297
+ def process_video(selected_display_name, video):
298
+ if not selected_display_name:
299
+ return "❌ Please select a model.", None, None
300
+ model_name = display_to_name.get(selected_display_name)
301
  return predict_video(model_name, video, models)
302
+
303
  # Connect the predict button
304
  video_predict_btn.click(
305
  fn=process_video,
306
  inputs=[model_dropdown, video_input],
307
  outputs=[video_status, video_output, video_download_btn]
308
  )
309
+
310
  gr.Markdown(
311
  """
312
  ---
313
+ **Note:** Models are downloaded from GitHub upon first use. Ensure that you have a stable internet connection and sufficient storage space.
314
  """
315
  )
316
+
317
  # Launch the Gradio app
318
  demo.launch()
319