|
import gradio as gr |
|
import cv2 |
|
import numpy as np |
|
import tempfile |
|
import os |
|
from pathlib import Path |
|
from ultralytics import YOLO |
|
import torch |
|
from typing import Optional, Tuple, List |
|
import logging |
|
|
|
|
|
# Configure root logging at INFO (a no-op if handlers are already installed)
# and create this module's named logger used throughout the file.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
|
|
class DroneDetectionPipeline:
    """
    Professional drone detection pipeline with YOLO + CSRT tracking.

    Per frame, the pipeline:
      1. runs YOLO with a low confidence floor (``low_confidence_threshold``);
      2. if any box reaches ``confidence_threshold``, draws the largest one as a
         detection (red) and (re)seeds the CSRT tracker on it;
      3. otherwise, once a confident detection has been seen, draws the largest
         low-confidence box as a track (blue) and re-seeds the tracker on it;
      4. otherwise, if the tracker is active, advances it with ``tracker.update``.

    Designed for UK International Lab deployment.
    """

    def __init__(self, model_path: str = "best.pt"):
        """Initialize the detection pipeline with a YOLO model and tracking state.

        Args:
            model_path: Path to custom YOLO weights. If the file is missing or
                cannot be loaded, the stock ``yolov8n.pt`` model is used and
                ``using_fallback_model`` is set to True.

        Raises:
            RuntimeError: If neither the custom nor the fallback model loads.
        """
        self.model_path = model_path
        self.model = None
        self.tracker = None
        self.tracking_active = False
        self.has_had_detection = False
        self.last_detection_bbox = None
        # Boxes scoring at least this are treated as confident detections.
        self.confidence_threshold = 0.5
        # Boxes between this floor and confidence_threshold are only used to keep
        # following a target that has already been confidently detected.
        self.low_confidence_threshold = 0.1
        # True when the stock yolov8n.pt weights were loaded instead of model_path.
        self.using_fallback_model = False

        self._load_model()

    # ------------------------------------------------------------------ model

    def _load_model(self) -> None:
        """Load the custom YOLO model, falling back to ``yolov8n.pt`` on failure.

        Raises:
            RuntimeError: If the fallback model cannot be loaded either.
        """
        if not os.path.exists(self.model_path):
            logger.error(f"❌ Model file not found: {self.model_path}")
            logger.info("🔄 Loading default YOLOv8n model as fallback...")
            self._load_fallback_model(
                "⚠️ Using default YOLOv8n model - detection accuracy may be reduced"
            )
            return

        try:
            self.model = self._load_custom_model(self.model_path)
            self.using_fallback_model = False
            logger.info(f"✅ Model loaded successfully from {self.model_path}")
        except Exception as e:
            logger.error(f"❌ Error loading model: {str(e)}")
            logger.info("🔄 Attempting to load default YOLOv8n model as final fallback...")
            self._load_fallback_model(
                "⚠️ Using default YOLOv8n model - please upload your custom 'best.pt' model for optimal drone detection"
            )

    @staticmethod
    def _load_custom_model(model_path: str):
        """Load custom YOLO weights, retrying with ``weights_only=False`` if needed.

        Recent PyTorch versions make ``torch.load`` default to ``weights_only=True``,
        which rejects Ultralytics checkpoints. ``DetectionModel`` is allow-listed
        first. If loading still fails, the load is retried with
        ``weights_only=False`` forced for its duration.
        SECURITY: that retry unpickles arbitrary objects; only use trusted weights.
        """
        from ultralytics.nn.tasks import DetectionModel

        # add_safe_globals only exists in newer PyTorch releases; calling it
        # unconditionally raised AttributeError on older ones and discarded the
        # custom model in favour of the fallback.
        add_safe_globals = getattr(torch.serialization, "add_safe_globals", None)
        if add_safe_globals is not None:
            add_safe_globals([DetectionModel])

        try:
            return YOLO(model_path)
        except Exception as pytorch_error:
            logger.warning(f"⚠️ Standard loading failed: {str(pytorch_error)}")
            logger.info("🔄 Attempting alternative loading method...")

        original_load = torch.load

        def patched_load(f, *args, **kwargs):
            kwargs['weights_only'] = False
            return original_load(f, *args, **kwargs)

        torch.load = patched_load
        try:
            return YOLO(model_path)
        finally:
            # Always restore the global so other torch.load callers are unaffected.
            torch.load = original_load

    def _load_fallback_model(self, warning: str) -> None:
        """Load the stock ``yolov8n.pt`` model and log *warning*.

        Raises:
            RuntimeError: If the fallback model cannot be loaded.
        """
        try:
            self.model = YOLO('yolov8n.pt')
        except Exception as fallback_error:
            logger.error(f"❌ Failed to load any model: {str(fallback_error)}")
            raise RuntimeError(
                "Could not load any YOLO model. Please check your model file and dependencies."
            ) from fallback_error
        self.using_fallback_model = True
        logger.warning(warning)

    # --------------------------------------------------------------- tracking

    def _reset_tracking_state(self):
        """Reset all tracking-related state variables."""
        self.tracker = None
        self.tracking_active = False
        self.has_had_detection = False
        self.last_detection_bbox = None
        logger.info("🔄 Tracking state reset")

    @staticmethod
    def _create_tracker():
        """Create a CSRT tracker, using ``cv2.legacy`` if the top-level factory is absent."""
        factory = getattr(cv2, "TrackerCSRT_create", None)
        if factory is None:
            factory = cv2.legacy.TrackerCSRT_create
        return factory()

    @staticmethod
    def _init_succeeded(init_result) -> bool:
        """Interpret the value returned by ``tracker.init``.

        OpenCV >= 4.5.1 declares ``Tracker::init`` as ``void``, so Python receives
        ``None`` and failures surface as exceptions. Legacy trackers return a bool.
        Treating ``None`` as failure (the old behaviour) meant tracking never
        activated on current OpenCV.
        """
        return True if init_result is None else bool(init_result)

    def _initialize_tracker(self, frame: np.ndarray, bbox: Tuple[int, int, int, int]) -> bool:
        """(Re)create a CSRT tracker seeded with *bbox*.

        Args:
            frame: Frame the box refers to.
            bbox: ``(x, y, width, height)`` in pixels.

        Returns:
            True if the tracker was initialised. On failure the tracker is
            dropped and ``tracking_active`` is cleared.
        """
        # OpenCV's Rect parsing can reject NumPy integer scalars; pass plain ints.
        bbox = tuple(int(v) for v in bbox)
        tracker = None
        try:
            tracker = self._create_tracker()
            success = self._init_succeeded(tracker.init(frame, bbox))
        except Exception as e:
            logger.error(f"❌ Error initializing tracker: {str(e)}")
            success = False
        else:
            if not success:
                logger.warning("⚠️ Tracker initialization failed")

        if success:
            self.tracker = tracker
            self.tracking_active = True
            self.last_detection_bbox = bbox
            # Runs on most frames, so keep it out of INFO-level logs.
            logger.debug("✅ Tracker initialized successfully")
        else:
            # Don't keep following a stale target with a tracker we failed to re-seed.
            self.tracker = None
            self.tracking_active = False
        return success

    # -------------------------------------------------------------- detection

    @staticmethod
    def _boxes_from_results(results) -> List[Tuple[int, int, int, int, float]]:
        """Flatten Ultralytics results into ``(x1, y1, x2, y2, confidence)`` tuples.

        Coordinates are truncated to plain Python ints, so OpenCV drawing and
        tracker APIs accept them.
        """
        detections = []
        for result in results:
            if result.boxes is None:
                continue
            for box in result.boxes:
                x1, y1, x2, y2 = (int(v) for v in box.xyxy[0].cpu().numpy())
                confidence = float(box.conf[0].cpu().numpy())
                detections.append((x1, y1, x2, y2, confidence))
        return detections

    @staticmethod
    def _largest_detection(detections):
        """Return the detection with the largest box area (first on ties), or None."""
        if not detections:
            return None
        return max(detections, key=lambda d: (d[2] - d[0]) * (d[3] - d[1]))

    def _detect_drones(self, frame: np.ndarray,
                       conf: Optional[float] = None) -> List[Tuple[int, int, int, int, float]]:
        """Run YOLO inference on *frame* and return detections.

        Args:
            frame: Image passed to the YOLO model.
            conf: Minimum confidence for YOLO; defaults to ``self.confidence_threshold``.

        Returns:
            ``(x1, y1, x2, y2, confidence)`` tuples; an empty list if inference fails.
        """
        threshold = self.confidence_threshold if conf is None else conf
        try:
            return self._boxes_from_results(self.model(frame, verbose=False, conf=threshold))
        except Exception as e:
            logger.error(f"❌ Error in detection: {str(e)}")
            return []

    # ---------------------------------------------------------------- drawing

    def _draw_detection(self, frame: np.ndarray, bbox: Tuple[int, int, int, int],
                        confidence: Optional[float] = None, is_tracking: bool = False) -> np.ndarray:
        """Draw a labelled bounding box on *frame* in place and return it.

        Args:
            frame: Image to draw on (modified in place).
            bbox: ``(x1, y1, x2, y2)`` corner coordinates.
            confidence: Shown in the label for detections when given.
            is_tracking: Blue "Drone (Tracker)" box if True, red detection box otherwise.
        """
        x1, y1, x2, y2 = (int(v) for v in bbox)

        # Colours are BGR: red for detections, blue for tracker output.
        color = (255, 0, 0) if is_tracking else (0, 0, 255)
        if is_tracking:
            label_text = "Drone (Tracker)"
        elif confidence is not None:
            label_text = f"Drone {confidence:.2f}"
        else:
            label_text = "Drone (Detection)"

        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

        (label_w, label_h), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
        # Push the label down when the box touches the top edge so it stays visible.
        label_bottom = max(y1, label_h + 10)
        cv2.rectangle(frame, (x1, label_bottom - label_h - 10),
                      (x1 + label_w, label_bottom), color, -1)
        cv2.putText(frame, label_text, (x1, label_bottom - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
        return frame

    # ------------------------------------------------------------------ video

    @staticmethod
    def _resolve_fps(raw_fps: float, default: float = 30.0) -> float:
        """Return *raw_fps* if it is a positive finite number, else *default*.

        Containers may report 0 (unknown) or fractional rates such as 29.97.
        Truncating to int caused drift, and 0 produced an unusable writer.
        """
        return float(raw_fps) if 0 < raw_fps < float("inf") else default

    @staticmethod
    def _open_writer(output_path: str, fps: float, frame: np.ndarray):
        """Open an ``mp4v`` writer sized to *frame*.

        Sizing comes from the decoded frame rather than container metadata,
        because ``cv2.VideoWriter`` silently drops frames of a different size.

        Raises:
            RuntimeError: If the writer cannot be opened.
        """
        height, width = frame.shape[:2]
        writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
        if not writer.isOpened():
            writer.release()
            raise RuntimeError(f"❌ Could not create output video: {output_path}")
        return writer

    def _annotate_frame(self, frame: np.ndarray) -> Tuple[np.ndarray, Optional[str]]:
        """Run one detection/tracking step and return ``(annotated_copy, outcome)``.

        ``outcome`` is ``"detection"``, ``"tracking"``, or ``None`` if nothing was drawn.
        Exceptions from the model or tracker propagate to the caller.
        """
        annotated = frame.copy()
        detections = self._boxes_from_results(
            self.model(frame, verbose=False, conf=self.low_confidence_threshold)
        )
        strong = self._largest_detection(
            [d for d in detections if d[4] >= self.confidence_threshold]
        )
        weak = self._largest_detection(
            [d for d in detections
             if self.low_confidence_threshold <= d[4] < self.confidence_threshold]
        )

        if strong is not None:
            x1, y1, x2, y2, conf = strong
            annotated = self._draw_detection(annotated, (x1, y1, x2, y2), conf, False)
            self._initialize_tracker(frame, (x1, y1, x2 - x1, y2 - y1))
            self.has_had_detection = True
            return annotated, "detection"

        if weak is not None and self.has_had_detection:
            x1, y1, x2, y2, _ = weak
            annotated = self._draw_detection(annotated, (x1, y1, x2, y2), None, True)
            # Re-seed on the weak box so the tracker follows the latest evidence.
            self._initialize_tracker(frame, (x1, y1, x2 - x1, y2 - y1))
            return annotated, "tracking"

        if self.has_had_detection and self.tracking_active and self.tracker is not None:
            ok, tracked = self.tracker.update(frame)
            if ok:
                x, y, w, h = (int(v) for v in tracked)
                self.last_detection_bbox = (x, y, w, h)
                annotated = self._draw_detection(annotated, (x, y, x + w, y + h), None, True)
                return annotated, "tracking"
            logger.debug("🔍 CSRT tracking failed")
            self.tracking_active = False

        return annotated, None

    def process_video(self, input_video_path: str, progress_callback=None) -> str:
        """
        Process an entire video with drone detection and tracking.

        Args:
            input_video_path: Path to the input video.
            progress_callback: Optional ``callback(fraction, message)``. It is called
                every 10 frames when the container reports a frame count.

        Returns:
            Path to the annotated MP4, written into a fresh temporary directory
            that is not cleaned up here (the caller serves the file).

        Raises:
            ValueError: If the video cannot be opened or yields no frames.
            RuntimeError: If the output writer cannot be created.
        """
        self._reset_tracking_state()

        cap = cv2.VideoCapture(input_video_path)
        if not cap.isOpened():
            cap.release()
            raise ValueError("❌ Could not open input video")

        fps = self._resolve_fps(cap.get(cv2.CAP_PROP_FPS))
        # Some containers don't store a frame count and report 0 or -1.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        output_dir = tempfile.mkdtemp()
        output_path = os.path.join(output_dir, "drone_detection_output.mp4")
        out = None

        frame_count = 0
        detection_count = 0
        tracking_count = 0

        logger.info(f"🎬 Processing video: {total_frames} frames at {fps} FPS")

        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                if out is None:
                    out = self._open_writer(output_path, fps, frame)

                annotated, outcome = self._annotate_frame(frame)
                if outcome == "detection":
                    detection_count += 1
                elif outcome == "tracking":
                    tracking_count += 1

                out.write(annotated)
                frame_count += 1

                # Skip progress when the frame count is unknown (avoids dividing by 0);
                # clamp because container frame counts are sometimes underestimates.
                if progress_callback and frame_count % 10 == 0 and total_frames > 0:
                    progress = min(frame_count / total_frames, 1.0)
                    progress_callback(progress, f"Processing frame {frame_count}/{total_frames}")

            if out is None:
                raise ValueError("❌ Input video contains no readable frames")

        except Exception as e:
            logger.error(f"❌ Error during video processing: {str(e)}")
            raise

        finally:
            cap.release()
            if out is not None:
                out.release()

        logger.info("✅ Processing completed!")
        logger.info(f"📊 Stats: {detection_count} detections, {tracking_count} tracking frames")

        return output_path
|
|
|
|
|
# Module-level singleton: the model is loaded once, when this module is imported,
# and the same instance is used by process_video_gradio for every request.
pipeline = DroneDetectionPipeline(model_path="best.pt")
|
|
|
def process_video_gradio(input_video): |
|
"""Gradio interface function with enhanced error handling""" |
|
if input_video is None: |
|
return None, "β Please upload a video file" |
|
|
|
try: |
|
|
|
if pipeline.model is None: |
|
return None, "β Model not loaded. Please check that 'best.pt' is uploaded to the space." |
|
|
|
|
|
file_size = os.path.getsize(input_video) |
|
max_size = 200 * 1024 * 1024 |
|
if file_size > max_size: |
|
return None, f"β File too large ({file_size/1024/1024:.1f}MB). Maximum size is 200MB." |
|
|
|
logger.info(f"πΉ Processing video: {input_video} ({file_size/1024/1024:.1f}MB)") |
|
|
|
|
|
output_path = pipeline.process_video(input_video) |
|
|
|
success_msg = "β
Video processing completed successfully!" |
|
if "yolov8n" in str(pipeline.model.model): |
|
success_msg += "\nβ οΈ Note: Using default YOLOv8n model. Upload 'best.pt' for optimal drone detection." |
|
|
|
return output_path, success_msg |
|
|
|
except Exception as e: |
|
error_msg = f"β Error processing video: {str(e)}" |
|
logger.error(error_msg) |
|
logger.error("Full traceback:", exc_info=True) |
|
return None, error_msg |
|
|
|
|
|
def create_interface():
    """Build the Gradio Blocks UI for the drone detection pipeline.

    Layout:
    - An input column with the video upload and the process button. It also
      shows one button per sample clip (``sample_drone_1..3.mp4``) found in
      the working directory.
    - An output column with the annotated video and a status textbox.
    - Usage notes and the current model status.

    Returns:
        The constructed ``gr.Blocks`` instance (not launched).
    """
    title = "🚁 Professional Drone Detection System"

    description = """
    **Advanced Drone Detection and Tracking Pipeline**

    This system uses state-of-the-art YOLO object detection combined with CSRT tracking for robust drone detection in video footage.

    **Features:**

    - Real-time drone detection with confidence scoring
    - Seamless tracking when detections are lost
    - Professional-grade output for research and analysis
    """

    # The UI exposes no confidence control, so the usage steps must not offer one.
    article = """
    ### Technical Details

    - **Detection Model**: Custom trained YOLO model optimized for drone detection
    - **Tracking Algorithm**: CSRT (Channel and Spatial Reliability Tracker)
    - **Output Format**: MP4 video with annotated detections and tracking

    ### Usage Instructions

    1. Upload your video file (supported formats: MP4, AVI, MOV, MKV)
    2. Click "Process Video" and wait for completion
    3. Download the processed video with drone annotations

    **Note**: Processing time depends on video length and resolution.
    """

    with gr.Blocks(theme=gr.themes.Soft(), title=title) as interface:
        gr.Markdown(f"# {title}")
        gr.Markdown(description)

        with gr.Row():
            with gr.Column():
                input_video = gr.Video(label="📹 Upload Video", format="mp4")
                process_btn = gr.Button("🚀 Process Video", variant="primary", size="lg")

                sample_files = ("sample_drone_1.mp4", "sample_drone_2.mp4", "sample_drone_3.mp4")
                sample_videos = [path for path in sample_files if os.path.exists(path)]

                if sample_videos:
                    gr.Markdown("**📁 Or try these test videos (Not a part of training process):**")
                    with gr.Row():
                        for i, sample_video in enumerate(sample_videos, 1):
                            sample_btn = gr.Button(f"📹 Sample {i}", size="sm")
                            # Bind the path as a default argument so each button keeps
                            # its own file (a plain closure would see only the last one).
                            sample_btn.click(fn=lambda path=sample_video: path, outputs=input_video)

            with gr.Column():
                output_video = gr.Video(label="📽️ Processed Video", format="mp4")
                status_text = gr.Textbox(label="📊 Processing Status", interactive=False)

        process_btn.click(
            fn=process_video_gradio,
            inputs=[input_video],
            outputs=[output_video, status_text],
            show_progress=True,
        )

        gr.Markdown(article)

        gr.Markdown("### 📋 System Requirements & Setup")
        gr.Markdown("""
        - **Input video formats**: MP4, AVI, MOV, MKV
        - **Maximum file size**: 200MB
        - **Recommended resolution**: Up to 1920x1080
        - **Processing time**: ~1-2 minutes per minute of video

        **📝 Setup Instructions:**

        1. Upload your custom trained `best.pt` model to this space for optimal drone detection
        2. If no custom model is found, the system will use YOLOv8n as fallback
        3. For best results, use the custom drone detection model trained specifically for your use case

        **🔧 Model Upload:** Go to "Files and versions" tab → Upload your `best.pt` file → Restart space
        """)

        # Prefer the pipeline's record of what actually loaded: best.pt may exist
        # yet fail to load. Fall back to a file check if the flag is unavailable.
        using_fallback = getattr(pipeline, "using_fallback_model", None)
        if using_fallback is None:
            using_fallback = not os.path.exists("best.pt")
        model_status = (
            "🟡 Default YOLOv8n (Upload best.pt for better results)"
            if using_fallback
            else "🟢 Custom Model"
        )
        gr.Markdown(f"**Current Model Status**: {model_status}")

    return interface
|
|
|
|
|
if __name__ == "__main__":
    # Listen on all network interfaces at port 7860, with no public share
    # link, and surface errors in the UI.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
        "quiet": False,
    }
    create_interface().launch(**launch_kwargs)