mlbench123 commited on
Commit
3015dd2
Β·
verified Β·
1 Parent(s): 7a23e97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -11
app.py CHANGED
@@ -32,17 +32,61 @@ class DroneDetectionPipeline:
32
  self._load_model()
33
 
34
  def _load_model(self) -> None:
35
- """Load YOLO model with error handling"""
36
  try:
37
  if os.path.exists(self.model_path):
38
- self.model = YOLO(self.model_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  logger.info(f"βœ… Model loaded successfully from {self.model_path}")
40
  else:
41
  logger.error(f"❌ Model file not found: {self.model_path}")
42
- raise FileNotFoundError(f"Model file not found: {self.model_path}")
 
 
 
 
43
  except Exception as e:
44
  logger.error(f"❌ Error loading model: {str(e)}")
45
- raise
 
 
 
 
 
 
46
 
47
  def _initialize_tracker(self, frame: np.ndarray, bbox: Tuple[int, int, int, int]) -> bool:
48
  """Initialize CSRT tracker with given bounding box"""
@@ -210,22 +254,39 @@ class DroneDetectionPipeline:
210
  pipeline = DroneDetectionPipeline()
211
 
212
  def process_video_gradio(input_video, confidence_threshold):
213
- """Gradio interface function"""
214
  if input_video is None:
215
  return None, "❌ Please upload a video file"
216
 
217
  try:
 
 
 
 
218
  # Update confidence threshold
219
  pipeline.confidence_threshold = confidence_threshold
220
 
 
 
 
 
 
 
 
 
221
  # Process video
222
  output_path = pipeline.process_video(input_video)
223
 
224
- return output_path, "βœ… Video processing completed successfully!"
 
 
 
 
225
 
226
  except Exception as e:
227
  error_msg = f"❌ Error processing video: {str(e)}"
228
  logger.error(error_msg)
 
229
  return None, error_msg
230
 
231
  # Gradio Interface
@@ -304,13 +365,24 @@ def create_interface():
304
  gr.Markdown(article)
305
 
306
  # Example section
307
- gr.Markdown("### πŸ“‹ System Requirements")
308
  gr.Markdown("""
309
- - Input video formats: MP4, AVI, MOV, MKV
310
- - Maximum file size: 200MB
311
- - Recommended resolution: Up to 1920x1080
312
- - Processing time: ~1-2 minutes per minute of video
 
 
 
 
 
 
 
313
  """)
 
 
 
 
314
 
315
  return interface
316
 
 
32
  self._load_model()
33
 
34
  def _load_model(self) -> None:
35
+ """Load YOLO model with error handling and PyTorch compatibility"""
36
  try:
37
  if os.path.exists(self.model_path):
38
+ # Handle PyTorch 2.6+ security restrictions for model loading
39
+ import torch
40
+ from ultralytics.nn.tasks import DetectionModel
41
+
42
+ # Add safe globals for ultralytics models
43
+ torch.serialization.add_safe_globals([DetectionModel])
44
+
45
+ # Load model with weights_only=False for older model formats
46
+ try:
47
+ self.model = YOLO(self.model_path)
48
+ except Exception as pytorch_error:
49
+ logger.warning(f"⚠️ Standard loading failed: {str(pytorch_error)}")
50
+ logger.info("πŸ”„ Attempting alternative loading method...")
51
+
52
+ # Alternative loading for older PyTorch model formats
53
+ import tempfile
54
+ import shutil
55
+
56
+ # Create temporary model with legacy loading
57
+ with tempfile.TemporaryDirectory() as temp_dir:
58
+ temp_model_path = os.path.join(temp_dir, "temp_model.pt")
59
+ shutil.copy2(self.model_path, temp_model_path)
60
+
61
+ # Patch torch.load temporarily for this specific model
62
+ original_load = torch.load
63
+ def patched_load(f, **kwargs):
64
+ kwargs['weights_only'] = False
65
+ return original_load(f, **kwargs)
66
+
67
+ torch.load = patched_load
68
+ try:
69
+ self.model = YOLO(temp_model_path)
70
+ finally:
71
+ torch.load = original_load
72
+
73
  logger.info(f"βœ… Model loaded successfully from {self.model_path}")
74
  else:
75
  logger.error(f"❌ Model file not found: {self.model_path}")
76
+ # Try to load a default YOLOv8 model if custom model is not found
77
+ logger.info("πŸ”„ Loading default YOLOv8n model as fallback...")
78
+ self.model = YOLO('yolov8n.pt')
79
+ logger.warning("⚠️ Using default YOLOv8n model - detection accuracy may be reduced")
80
+
81
  except Exception as e:
82
  logger.error(f"❌ Error loading model: {str(e)}")
83
+ logger.info("πŸ”„ Attempting to load default YOLOv8n model as final fallback...")
84
+ try:
85
+ self.model = YOLO('yolov8n.pt')
86
+ logger.warning("⚠️ Using default YOLOv8n model - please upload your custom 'best.pt' model for optimal drone detection")
87
+ except Exception as fallback_error:
88
+ logger.error(f"❌ Failed to load any model: {str(fallback_error)}")
89
+ raise RuntimeError("Could not load any YOLO model. Please check your model file and dependencies.")
90
 
91
  def _initialize_tracker(self, frame: np.ndarray, bbox: Tuple[int, int, int, int]) -> bool:
92
  """Initialize CSRT tracker with given bounding box"""
 
254
  pipeline = DroneDetectionPipeline()
255
 
256
  def process_video_gradio(input_video, confidence_threshold):
257
+ """Gradio interface function with enhanced error handling"""
258
  if input_video is None:
259
  return None, "❌ Please upload a video file"
260
 
261
  try:
262
+ # Check if model is loaded properly
263
+ if pipeline.model is None:
264
+ return None, "❌ Model not loaded. Please check that 'best.pt' is uploaded to the space."
265
+
266
  # Update confidence threshold
267
  pipeline.confidence_threshold = confidence_threshold
268
 
269
+ # Add file size check
270
+ file_size = os.path.getsize(input_video)
271
+ max_size = 200 * 1024 * 1024 # 200MB
272
+ if file_size > max_size:
273
+ return None, f"❌ File too large ({file_size/1024/1024:.1f}MB). Maximum size is 200MB."
274
+
275
+ logger.info(f"πŸ“Ή Processing video: {input_video} ({file_size/1024/1024:.1f}MB)")
276
+
277
  # Process video
278
  output_path = pipeline.process_video(input_video)
279
 
280
+ success_msg = "βœ… Video processing completed successfully!"
281
+ if "yolov8n" in str(pipeline.model.model):
282
+ success_msg += "\n⚠️ Note: Using default YOLOv8n model. Upload 'best.pt' for optimal drone detection."
283
+
284
+ return output_path, success_msg
285
 
286
  except Exception as e:
287
  error_msg = f"❌ Error processing video: {str(e)}"
288
  logger.error(error_msg)
289
+ logger.error("Full traceback:", exc_info=True)
290
  return None, error_msg
291
 
292
  # Gradio Interface
 
365
  gr.Markdown(article)
366
 
367
  # Example section
368
+ gr.Markdown("### πŸ“‹ System Requirements & Setup")
369
  gr.Markdown("""
370
+ - **Input video formats**: MP4, AVI, MOV, MKV
371
+ - **Maximum file size**: 200MB
372
+ - **Recommended resolution**: Up to 1920x1080
373
+ - **Processing time**: ~1-2 minutes per minute of video
374
+
375
+ **πŸ“ Setup Instructions:**
376
+ 1. Upload your custom trained `best.pt` model to this space for optimal drone detection
377
+ 2. If no custom model is found, the system will use YOLOv8n as fallback
378
+ 3. For best results, use the custom drone detection model trained specifically for your use case
379
+
380
+ **πŸ”§ Model Upload:** Go to "Files and versions" tab β†’ Upload your `best.pt` file β†’ Restart space
381
  """)
382
+
383
+ # Status indicator
384
+ model_status = "🟒 Custom Model" if os.path.exists("best.pt") else "🟑 Default YOLOv8n (Upload best.pt for better results)"
385
+ gr.Markdown(f"**Current Model Status**: {model_status}")
386
 
387
  return interface
388