import gradio as gr
import cv2
import numpy as np
import tempfile
import os
from pathlib import Path
from ultralytics import YOLO
import torch
from typing import Optional, Tuple, List
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Weights loaded when the custom model is missing or fails to load.
FALLBACK_MODEL = "yolov8n.pt"
# Frame rate assumed when the container does not report a usable one.
DEFAULT_FPS = 30.0
# Upload limit enforced by the Gradio handler (200MB).
MAX_UPLOAD_BYTES = 200 * 1024 * 1024

# (x1, y1, x2, y2, confidence) in pixel coordinates.
Detection = Tuple[int, int, int, int, float]


class DroneDetectionPipeline:
    """
    Professional drone detection pipeline with YOLO + CSRT tracking
    Designed for UK International Lab deployment

    For every frame the pipeline draws at most one box, chosen in priority
    order:
      1. the largest high-confidence YOLO detection (red box), which also
         re-seeds the CSRT tracker;
      2. once a drone has been detected, the largest low-confidence YOLO
         detection (blue box), which also re-seeds the tracker;
      3. otherwise the CSRT tracker's prediction (blue box).
    """

    def __init__(self, model_path: str = "best.pt",
                 confidence_threshold: float = 0.5,
                 tracking_confidence_threshold: float = 0.1):
        """Initialize the detection pipeline with a YOLO model and tracking state.

        Args:
            model_path: Path to the custom YOLO weights.
            confidence_threshold: Minimum score for a detection (red box).
            tracking_confidence_threshold: Minimum score for a low-confidence
                detection that may continue an existing track (blue box).

        Raises:
            RuntimeError: If neither the custom nor the fallback model loads.
        """
        self.model_path = model_path
        self.model = None
        # True when FALLBACK_MODEL was loaded instead of model_path.
        self.using_fallback_model = False
        self.tracker = None
        self.tracking_active = False
        self.has_had_detection = False
        self.last_detection_bbox = None
        self.confidence_threshold = confidence_threshold
        self.tracking_confidence_threshold = tracking_confidence_threshold

        # Load model
        self._load_model()

    # ------------------------------------------------------------------ model

    def _load_model(self) -> None:
        """Load the YOLO model, falling back to the default YOLOv8n weights.

        Sets ``self.model`` and ``self.using_fallback_model``.

        Raises:
            RuntimeError: If no model at all could be loaded.
        """
        try:
            if os.path.exists(self.model_path):
                self.model = self._load_custom_model(self.model_path)
                self.using_fallback_model = False
                logger.info(f"✅ Model loaded successfully from {self.model_path}")
            else:
                logger.error(f"❌ Model file not found: {self.model_path}")
                # Try to load a default YOLOv8 model if custom model is not found
                logger.info("🔄 Loading default YOLOv8n model as fallback...")
                self.model = YOLO(FALLBACK_MODEL)
                self.using_fallback_model = True
                logger.warning("⚠️ Using default YOLOv8n model - detection accuracy may be reduced")
        except Exception as e:
            logger.error(f"❌ Error loading model: {str(e)}")
            logger.info("🔄 Attempting to load default YOLOv8n model as final fallback...")
            try:
                self.model = YOLO(FALLBACK_MODEL)
                self.using_fallback_model = True
                logger.warning("⚠️ Using default YOLOv8n model - please upload your custom 'best.pt' model for optimal drone detection")
            except Exception as fallback_error:
                logger.error(f"❌ Failed to load any model: {str(fallback_error)}")
                raise RuntimeError(
                    "Could not load any YOLO model. Please check your model file and dependencies."
                ) from fallback_error

    @staticmethod
    def _load_custom_model(model_path: str) -> YOLO:
        """Load custom weights, working around PyTorch 2.6+ ``weights_only`` loading.

        The checkpoint is treated as trusted (it is supplied by the deployment
        owner). The legacy path below forces ``weights_only=False``, which
        allows arbitrary pickle code execution. Never use it on untrusted files.

        Raises:
            Exception: Whatever YOLO raises if both loading attempts fail.
        """
        # torch.serialization.add_safe_globals only exists on torch >= 2.4.
        # Calling it unconditionally raised AttributeError on older versions,
        # which made the caller silently discard the custom model.
        add_safe_globals = getattr(torch.serialization, "add_safe_globals", None)
        if add_safe_globals is not None:
            try:
                from ultralytics.nn.tasks import DetectionModel
                add_safe_globals([DetectionModel])
            except ImportError as err:
                logger.debug(f"Could not register DetectionModel as a safe global: {err}")

        try:
            return YOLO(model_path)
        except Exception as pytorch_error:
            logger.warning(f"⚠️ Standard loading failed: {str(pytorch_error)}")
            logger.info("🔄 Attempting alternative loading method...")

        # Patch torch.load temporarily so older checkpoint formats can be unpickled.
        original_load = torch.load

        def patched_load(f, *args, **kwargs):
            kwargs['weights_only'] = False
            return original_load(f, *args, **kwargs)

        torch.load = patched_load
        try:
            return YOLO(model_path)
        finally:
            torch.load = original_load

    # --------------------------------------------------------------- tracking

    def _reset_tracking_state(self):
        """Reset all tracking-related state variables"""
        self.tracker = None
        self.tracking_active = False
        self.has_had_detection = False
        self.last_detection_bbox = None
        logger.info("🔄 Tracking state reset")

    @staticmethod
    def _create_csrt_tracker():
        """Create a CSRT tracker across the OpenCV API variants.

        Raises:
            RuntimeError: If this OpenCV build has no CSRT tracker
                (it ships with opencv-contrib-python).
        """
        factories = (
            getattr(cv2, "TrackerCSRT_create", None),
            getattr(getattr(cv2, "TrackerCSRT", None), "create", None),
            getattr(getattr(cv2, "legacy", None), "TrackerCSRT_create", None),
        )
        for factory in factories:
            if factory is not None:
                return factory()
        raise RuntimeError("CSRT tracker is unavailable in this OpenCV build")

    def _initialize_tracker(self, frame: np.ndarray, bbox: Tuple[int, int, int, int]) -> bool:
        """Initialize a fresh CSRT tracker on ``bbox`` given as (x, y, w, h).

        On failure the tracker is discarded and tracking is deactivated, so a
        half-initialized tracker is never ``update()``-d later.

        Returns:
            True if the tracker is ready for ``update``.
        """
        # OpenCV's Rect argument parser expects plain Python ints.
        bbox = tuple(int(v) for v in bbox)
        try:
            tracker = self._create_csrt_tracker()
            # The OpenCV >= 4.5.1 tracker API returns None from init() and
            # raises on error; legacy trackers return a bool.
            result = tracker.init(frame, bbox)
        except Exception as e:
            logger.error(f"❌ Error initializing tracker: {str(e)}")
            self.tracker = None
            self.tracking_active = False
            return False

        if result is not None and not result:
            logger.warning("⚠️ Tracker initialization failed")
            self.tracker = None
            self.tracking_active = False
            return False

        self.tracker = tracker
        self.tracking_active = True
        self.last_detection_bbox = bbox
        # Debug level: this runs on every frame that has a detection.
        logger.debug("✅ Tracker initialized successfully")
        return True

    # -------------------------------------------------------------- detection

    def _run_inference(self, frame: np.ndarray, conf: float) -> List[Detection]:
        """Run YOLO on ``frame`` and return all boxes scoring above ``conf``.

        Returns:
            A list of ``(x1, y1, x2, y2, confidence)`` with integer pixel
            coordinates. Errors from the model propagate to the caller.
        """
        detections = []
        for result in self.model(frame, verbose=False, conf=conf):
            if result.boxes is None:
                continue
            for box in result.boxes:
                # Extract coordinates and confidence (int() truncates like astype(int))
                x1, y1, x2, y2 = (int(v) for v in box.xyxy[0].cpu().numpy())
                confidence = float(box.conf[0].cpu().numpy())
                detections.append((x1, y1, x2, y2, confidence))
        return detections

    def _detect_drones(self, frame: np.ndarray) -> List[Tuple[int, int, int, int, float]]:
        """Run YOLO inference at ``confidence_threshold``; return [] on error."""
        try:
            return self._run_inference(frame, self.confidence_threshold)
        except Exception as e:
            logger.error(f"❌ Error in detection: {str(e)}")
            return []

    @staticmethod
    def _box_area(detection: Detection) -> int:
        """Area in pixels of an ``(x1, y1, x2, y2, ...)`` box."""
        return (detection[2] - detection[0]) * (detection[3] - detection[1])

    def _draw_detection(self, frame: np.ndarray, bbox: Tuple[int, int, int, int],
                        confidence: Optional[float] = None, is_tracking: bool = False) -> np.ndarray:
        """Draw a labelled box given as (x1, y1, x2, y2) onto ``frame`` in place.

        Red boxes are detections (labelled with the confidence when given);
        blue boxes are tracker / low-confidence boxes. Colors are BGR.

        Returns:
            The same ``frame`` object, for chaining.
        """
        x1, y1, x2, y2 = (int(v) for v in bbox)

        # Choose color: Red for detection, Blue for tracking
        color = (255, 0, 0) if is_tracking else (0, 0, 255)
        if is_tracking:
            label_text = "Drone (Tracker)"
        elif confidence is not None:
            label_text = f"Drone {confidence:.2f}"
        else:
            label_text = "Drone (Detection)"

        # Draw bounding box
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

        # Label sits above the box; move it inside when the box touches the
        # top edge, otherwise it would be drawn off-screen.
        (label_w, label_h), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
        label_top = y1 - label_h - 10
        if label_top < 0:
            label_top = y1
        label_bottom = label_top + label_h + 10

        cv2.rectangle(frame, (x1, label_top), (x1 + label_w, label_bottom), color, -1)
        cv2.putText(frame, label_text, (x1, label_bottom - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
        return frame

    # ------------------------------------------------------------------ video

    def _process_frame(self, frame: np.ndarray) -> Tuple[np.ndarray, Optional[str]]:
        """Annotate one frame and update the tracking state.

        Returns:
            An annotated copy of ``frame``, plus which source produced the
            box: ``"detection"`` (high confidence), ``"tracking"``
            (low-confidence detection or CSRT prediction), or ``None`` when
            nothing was drawn.
        """
        annotated = frame.copy()

        # One inference pass at the lower threshold; scores are then split
        # into detections and track-continuation candidates.
        floor = min(self.tracking_confidence_threshold, self.confidence_threshold)
        detections = self._run_inference(frame, floor)
        high_conf = [d for d in detections if d[4] >= self.confidence_threshold]
        low_conf = [d for d in detections if floor <= d[4] < self.confidence_threshold]

        # PRIORITY 1: High confidence detection (RED BOX) re-seeds the tracker.
        if high_conf:
            x1, y1, x2, y2, conf = max(high_conf, key=self._box_area)
            annotated = self._draw_detection(annotated, (x1, y1, x2, y2), conf, False)
            self._initialize_tracker(frame, (x1, y1, x2 - x1, y2 - y1))
            self.has_had_detection = True
            return annotated, "detection"

        # PRIORITY 2: Low confidence detection (BLUE BOX), only after a
        # high-confidence detection has been seen in this video.
        if low_conf and self.has_had_detection:
            x1, y1, x2, y2, _ = max(low_conf, key=self._box_area)
            annotated = self._draw_detection(annotated, (x1, y1, x2, y2), None, True)
            self._initialize_tracker(frame, (x1, y1, x2 - x1, y2 - y1))
            return annotated, "tracking"

        # PRIORITY 3: CSRT prediction when YOLO found nothing usable.
        if self.has_had_detection and self.tracking_active and self.tracker is not None:
            ok, tracked = self.tracker.update(frame)
            if ok:
                x, y, w, h = (int(v) for v in tracked)
                self.last_detection_bbox = (x, y, w, h)
                annotated = self._draw_detection(annotated, (x, y, x + w, y + h), None, True)
                return annotated, "tracking"
            # Track lost: stop predicting until a YOLO detection re-seeds it.
            logger.debug("🔍 CSRT tracking failed")
            self.tracking_active = False

        return annotated, None

    @staticmethod
    def _open_writer(output_path: str, fps: float, frame: np.ndarray):
        """Open an MP4 writer sized to the decoded ``frame``.

        Sizing from the actual frame (rather than CAP_PROP_FRAME_WIDTH/HEIGHT)
        matters because VideoWriter silently drops frames whose size differs
        from the size it was opened with.

        Raises:
            RuntimeError: If the writer cannot be opened.
        """
        height, width = frame.shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not writer.isOpened():
            raise RuntimeError(f"❌ Could not open video writer for {output_path}")
        return writer

    def process_video(self, input_video_path: str, progress_callback=None) -> str:
        """
        Process entire video with drone detection and tracking

        Args:
            input_video_path: Path to input video
            progress_callback: Optional callable ``(fraction, message)``,
                invoked every 10 frames when the frame count is known

        Returns:
            Path to processed output video (in a fresh temporary directory)

        Raises:
            ValueError: If the video cannot be opened or has no readable frames.
            RuntimeError: If the output writer cannot be opened.
        """
        # IMPORTANT: Reset tracking state at the beginning of each video
        self._reset_tracking_state()

        cap = cv2.VideoCapture(input_video_path)
        if not cap.isOpened():
            raise ValueError("❌ Could not open input video")

        # Create temporary output file
        output_dir = tempfile.mkdtemp()
        output_path = os.path.join(output_dir, "drone_detection_output.mp4")

        # Keep fractional FPS (e.g. 29.97) so output duration does not drift;
        # `not fps > 0` also catches NaN.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = DEFAULT_FPS
        # May be 0 or negative for containers that do not report it.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        frame_count = 0
        detection_count = 0
        tracking_count = 0
        out = None

        logger.info(f"🎬 Processing video: {total_frames} frames at {fps:.2f} FPS")

        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                if out is None:
                    out = self._open_writer(output_path, fps, frame)

                frame_processed, source = self._process_frame(frame)
                if source == "detection":
                    detection_count += 1
                elif source == "tracking":
                    tracking_count += 1

                # Write processed frame
                out.write(frame_processed)
                frame_count += 1

                # Update progress (skipped when the frame count is unknown)
                if progress_callback and total_frames > 0 and frame_count % 10 == 0:
                    progress = min(frame_count / total_frames, 1.0)
                    progress_callback(progress, f"Processing frame {frame_count}/{total_frames}")

        except Exception as e:
            logger.error(f"❌ Error during video processing: {str(e)}")
            raise
        finally:
            cap.release()
            if out is not None:
                out.release()

        if frame_count == 0:
            raise ValueError("❌ Input video contains no readable frames")

        logger.info("✅ Processing completed!")
        logger.info(f"📊 Stats: {detection_count} detections, {tracking_count} tracking frames")

        return output_path


# Initialize pipeline
pipeline = DroneDetectionPipeline()


def process_video_gradio(input_video, confidence_threshold=None, progress=gr.Progress()):
    """Gradio handler: validate the upload, run the pipeline, report status.

    Args:
        input_video: Filepath of the uploaded video (``None`` if nothing uploaded).
        confidence_threshold: Optional detection threshold from the UI slider;
            when given it is stored on the shared ``pipeline``.
        progress: Injected by Gradio to report per-frame progress.

    Returns:
        ``(output_video_path or None, status_message)``.
    """
    if input_video is None:
        return None, "❌ Please upload a video file"

    try:
        # Check if model is loaded properly
        if pipeline.model is None:
            return None, "❌ Model not loaded. Please check that 'best.pt' is uploaded to the space."

        # Add file size check
        file_size = os.path.getsize(input_video)
        if file_size > MAX_UPLOAD_BYTES:
            return None, f"❌ File too large ({file_size/1024/1024:.1f}MB). Maximum size is 200MB."

        if confidence_threshold is not None:
            pipeline.confidence_threshold = float(confidence_threshold)

        logger.info(f"📹 Processing video: {input_video} ({file_size/1024/1024:.1f}MB)")

        # Process video
        output_path = pipeline.process_video(input_video, progress_callback=progress)

        success_msg = "✅ Video processing completed successfully!"
        # Use the flag set by the loader; the network repr never contains the weights name.
        if pipeline.using_fallback_model:
            success_msg += "\n⚠️ Note: Using default YOLOv8n model. Upload 'best.pt' for optimal drone detection."

        return output_path, success_msg

    except Exception as e:
        error_msg = f"❌ Error processing video: {str(e)}"
        logger.exception(error_msg)
        return None, error_msg


# Gradio Interface
def create_interface():
    """Create professional Gradio interface"""

    title = "🚁 Professional Drone Detection System"

    description = """
    **Advanced Drone Detection and Tracking Pipeline**

    This system uses state-of-the-art YOLO object detection combined with CSRT tracking for robust drone detection in video footage.

    **Features:**
    - Real-time drone detection with confidence scoring
    - Seamless tracking when detections are lost
    - Professional-grade output for research and analysis
    """

    article = """
    ### Technical Details
    - **Detection Model**: Custom trained YOLO model optimized for drone detection
    - **Tracking Algorithm**: CSRT (Channel and Spatial Reliability Tracker)
    - **Output Format**: MP4 video with annotated detections and tracking

    ### Usage Instructions
    1. Upload your video file (supported formats: MP4, AVI, MOV)
    2. Adjust confidence threshold if needed (0.1-0.9)
    3. Click "Process Video" and wait for completion
    4. Download the processed video with drone annotations

    **Note**: Processing time depends on video length and resolution.
    """

    with gr.Blocks(theme=gr.themes.Soft(), title=title) as interface:
        gr.Markdown(f"# {title}")
        gr.Markdown(description)

        with gr.Row():
            with gr.Column():
                input_video = gr.Video(
                    label="📹 Upload Video",
                    format="mp4"
                )

                confidence_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=pipeline.confidence_threshold,
                    step=0.05,
                    label="🎯 Confidence Threshold"
                )

                process_btn = gr.Button(
                    "🚀 Process Video",
                    variant="primary",
                    size="lg"
                )

                sample_files = ["sample_drone_1.mp4", "sample_drone_2.mp4", "sample_drone_3.mp4"]
                sample_videos = [f for f in sample_files if os.path.exists(f)]

                if sample_videos:
                    gr.Markdown("**📂 Or try these test videos (Not a part of training process):**")

                    with gr.Row():
                        for i, sample_video in enumerate(sample_videos, 1):
                            sample_btn = gr.Button(f"📹 Sample {i}", size="sm")
                            # Default argument binds the current path (avoids late binding).
                            sample_btn.click(
                                fn=lambda x=sample_video: x,
                                outputs=input_video
                            )

            with gr.Column():
                output_video = gr.Video(
                    label="📽️ Processed Video",
                    format="mp4"
                )

                status_text = gr.Textbox(
                    label="📊 Processing Status",
                    interactive=False
                )

        process_btn.click(
            fn=process_video_gradio,
            inputs=[input_video, confidence_slider],
            outputs=[output_video, status_text],
            show_progress=True
        )

        gr.Markdown(article)

        # Example section
        gr.Markdown("### 📋 System Requirements & Setup")
        gr.Markdown("""
        - **Input video formats**: MP4, AVI, MOV, MKV
        - **Maximum file size**: 200MB
        - **Recommended resolution**: Up to 1920x1080
        - **Processing time**: ~1-2 minutes per minute of video

        **📝 Setup Instructions:**
        1. Upload your custom trained `best.pt` model to this space for optimal drone detection
        2. If no custom model is found, the system will use YOLOv8n as fallback
        3. For best results, use the custom drone detection model trained specifically for your use case

        **🔧 Model Upload:**
        Go to "Files and versions" tab → Upload your `best.pt` file → Restart space
        """)

        # Status indicator reflects what was actually loaded, not just file presence.
        if pipeline.using_fallback_model:
            model_status = "🟡 Default YOLOv8n (Upload best.pt for better results)"
        else:
            model_status = "🟢 Custom Model"
        gr.Markdown(f"**Current Model Status**: {model_status}")

    return interface


# Launch the application
if __name__ == "__main__":
    interface = create_interface()
    # Queuing is required for gr.Progress updates (mandatory on Gradio 3.x).
    interface.queue()
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        quiet=False
    )