import gradio as gr
import cv2
import numpy as np
import tempfile
import os
from pathlib import Path
from ultralytics import YOLO
import torch
from typing import Optional, Tuple, List
import logging
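# Dependency note: CSRT tracking lives in OpenCV's contrib "tracking" module,
# so this script assumes the opencv-contrib-python build (the plain
# opencv-python wheel does not ship TrackerCSRT), plus ultralytics for YOLO.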
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DroneDetectionPipeline:
"""
Professional drone detection pipeline with YOLO + CSRT tracking
Designed for UK International Lab deployment
"""
def __init__(self, model_path: str = "best.pt"):
"""Initialize the detection pipeline with YOLO model and CSRT tracker"""
self.model_path = model_path
self.model = None
self.tracker = None
self.tracking_active = False
self.has_had_detection = False
self.last_detection_bbox = None
self.confidence_threshold = 0.5
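        # State notes: has_had_detection gates the blue (low-confidence /
        # tracker) boxes so they only appear after the first confirmed red
        # detection; last_detection_bbox holds the most recent box as (x, y, w, h).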
# Load model
self._load_model()
def _load_model(self) -> None:
"""Load YOLO model with error handling and PyTorch compatibility"""
try:
if os.path.exists(self.model_path):
# Handle PyTorch 2.6+ security restrictions for model loading
import torch
from ultralytics.nn.tasks import DetectionModel
# Add safe globals for ultralytics models
torch.serialization.add_safe_globals([DetectionModel])
                # First try standard loading; fall back below for older checkpoint formats
try:
self.model = YOLO(self.model_path)
except Exception as pytorch_error:
logger.warning(f"β οΈ Standard loading failed: {str(pytorch_error)}")
logger.info("π Attempting alternative loading method...")
# Alternative loading for older PyTorch model formats
import tempfile
import shutil
# Create temporary model with legacy loading
with tempfile.TemporaryDirectory() as temp_dir:
temp_model_path = os.path.join(temp_dir, "temp_model.pt")
shutil.copy2(self.model_path, temp_model_path)
# Patch torch.load temporarily for this specific model
original_load = torch.load
def patched_load(f, **kwargs):
kwargs['weights_only'] = False
return original_load(f, **kwargs)
torch.load = patched_load
try:
self.model = YOLO(temp_model_path)
finally:
torch.load = original_load
logger.info(f"β
Model loaded successfully from {self.model_path}")
else:
logger.error(f"β Model file not found: {self.model_path}")
# Try to load a default YOLOv8 model if custom model is not found
logger.info("π Loading default YOLOv8n model as fallback...")
self.model = YOLO('yolov8n.pt')
logger.warning("β οΈ Using default YOLOv8n model - detection accuracy may be reduced")
except Exception as e:
logger.error(f"β Error loading model: {str(e)}")
logger.info("π Attempting to load default YOLOv8n model as final fallback...")
try:
self.model = YOLO('yolov8n.pt')
logger.warning("β οΈ Using default YOLOv8n model - please upload your custom 'best.pt' model for optimal drone detection")
except Exception as fallback_error:
logger.error(f"β Failed to load any model: {str(fallback_error)}")
raise RuntimeError("Could not load any YOLO model. Please check your model file and dependencies.")
def _reset_tracking_state(self):
"""Reset all tracking-related state variables"""
self.tracker = None
self.tracking_active = False
self.has_had_detection = False
self.last_detection_bbox = None
logger.info("π Tracking state reset")
def _initialize_tracker(self, frame: np.ndarray, bbox: Tuple[int, int, int, int]) -> bool:
"""Initialize CSRT tracker with given bounding box"""
        try:
            # TrackerCSRT lives in the top-level namespace on some OpenCV builds
            # and under cv2.legacy on others (opencv-contrib-python is required
            # either way); try both for compatibility.
            if hasattr(cv2, "TrackerCSRT_create"):
                self.tracker = cv2.TrackerCSRT_create()
            else:
                self.tracker = cv2.legacy.TrackerCSRT_create()
            # Tracker.init returns a bool in the legacy API but None in
            # OpenCV >= 4.5.1, so treat anything except an explicit False as success
            success = self.tracker.init(frame, bbox) is not False
            if success:
                self.tracking_active = True
                self.last_detection_bbox = bbox
                logger.info("✅ Tracker initialized successfully")
            else:
                logger.warning("⚠️ Tracker initialization failed")
            return success
        except Exception as e:
            logger.error(f"❌ Error initializing tracker: {str(e)}")
            return False
def _detect_drones(self, frame: np.ndarray) -> List[Tuple[int, int, int, int, float]]:
"""Run YOLO inference on frame and return detections"""
try:
results = self.model(frame, verbose=False, conf=self.confidence_threshold)
detections = []
for result in results:
if result.boxes is not None:
for box in result.boxes:
# Extract coordinates and confidence
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
confidence = float(box.conf[0].cpu().numpy())
detections.append((x1, y1, x2, y2, confidence))
return detections
except Exception as e:
logger.error(f"β Error in detection: {str(e)}")
return []
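    # NOTE: _detect_drones is a single-threshold helper; process_video below
    # inlines a two-threshold variant so it can split high- and low-confidence
    # detections in a single YOLO pass.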
def _draw_detection(self, frame: np.ndarray, bbox: Tuple[int, int, int, int],
confidence: float = None, is_tracking: bool = False) -> np.ndarray:
"""Draw bounding box and label on frame"""
x1, y1, x2, y2 = bbox
# Choose color: Red for detection, Blue for tracking
color = (0, 0, 255) if not is_tracking else (255, 0, 0)
label_text = f"Drone (Detection)" if not is_tracking else f"Drone (Tracker)"
if confidence is not None and not is_tracking:
label_text = f"Drone {confidence:.2f}"
# Draw bounding box
cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
# Draw label background
label_size = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
cv2.rectangle(frame, (x1, y1 - label_size[1] - 10),
(x1 + label_size[0], y1), color, -1)
# Draw label text
cv2.putText(frame, label_text, (x1, y1 - 5),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
return frame
def process_video(self, input_video_path: str, progress_callback=None) -> str:
"""
Process entire video with drone detection and tracking
Args:
input_video_path: Path to input video
progress_callback: Optional callback for progress updates
Returns:
Path to processed output video
"""
# IMPORTANT: Reset tracking state at the beginning of each video
self._reset_tracking_state()
# Create temporary output file
output_dir = tempfile.mkdtemp()
output_path = os.path.join(output_dir, "drone_detection_output.mp4")
cap = cv2.VideoCapture(input_video_path)
if not cap.isOpened():
            raise ValueError("❌ Could not open input video")
# Get video properties
        # Some containers report 0 FPS; fall back to a sane default so the writer works
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# Initialize video writer
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
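        # NOTE (assumption): 'mp4v' keeps the writer portable, but some browsers
        # cannot play MPEG-4 Part 2 inline; if the Gradio player shows a blank
        # video, H.264 ('avc1') is a common alternative when the local
        # OpenCV/FFmpeg build provides that encoder.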
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
frame_count = 0
detection_count = 0
tracking_count = 0
logger.info(f"π¬ Processing video: {total_frames} frames at {fps} FPS")
try:
while True:
ret, frame = cap.read()
if not ret:
break
frame_processed = frame.copy()
detection_made_this_frame = False
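                # Per-frame strategy, in priority order:
                #   1. High-confidence YOLO detection (red box) always wins and
                #      reinitializes the CSRT tracker.
                #   2. Otherwise a low-confidence detection (blue box) re-anchors
                #      the tracker, but only after the first confirmed detection.
                #   3. Otherwise fall back to pure CSRT tracking from the last box.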
# Get all detections with different confidence levels
all_detections = self.model(frame, verbose=False, conf=0.1) # Low threshold to get all detections
high_conf_detections = []
low_conf_detections = []
for result in all_detections:
if result.boxes is not None:
for box in result.boxes:
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
confidence = float(box.conf[0].cpu().numpy())
if confidence >= self.confidence_threshold: # Use the actual confidence threshold
high_conf_detections.append((x1, y1, x2, y2, confidence))
elif confidence >= 0.1: # Lower threshold for tracking continuation
low_conf_detections.append((x1, y1, x2, y2, confidence))
# PRIORITY 1: High confidence detection (RED BOX)
if high_conf_detections:
# Select the largest high confidence detection
largest_detection = max(high_conf_detections, key=lambda d: (d[2]-d[0]) * (d[3]-d[1]))
x1, y1, x2, y2, conf = largest_detection
detection_bbox = (x1, y1, x2-x1, y2-y1)
# Draw red detection box
frame_processed = self._draw_detection(frame_processed, (x1, y1, x2, y2), conf, False)
# Initialize/reinitialize tracker with this high confidence detection
self._initialize_tracker(frame, detection_bbox)
self.has_had_detection = True
detection_count += 1
detection_made_this_frame = True
# PRIORITY 2: Low confidence detection (BLUE BOX) - only if we've had a detection before
elif low_conf_detections and self.has_had_detection:
# Select the largest low confidence detection
largest_low_conf = max(low_conf_detections, key=lambda d: (d[2]-d[0]) * (d[3]-d[1]))
x1, y1, x2, y2, conf = largest_low_conf
tracking_bbox = (x1, y1, x2-x1, y2-y1)
# Draw blue tracking box for low confidence detection
frame_processed = self._draw_detection(frame_processed, (x1, y1, x2, y2), None, True)
                    # Reinitialize tracker with the low confidence detection
                    # (init returns None on OpenCV >= 4.5.1, so only an explicit
                    # False counts as failure)
                    if self.tracker is not None:
                        if self.tracker.init(frame, tracking_bbox) is not False:
                            self.tracking_active = True
                            self.last_detection_bbox = tracking_bbox
                    tracking_count += 1
                    detection_made_this_frame = True
# PRIORITY 3: CSRT tracking - only if we've had a detection and no current detections
elif self.has_had_detection and self.tracker is not None and self.tracking_active and not detection_made_this_frame:
success, tracking_bbox = self.tracker.update(frame)
if success:
x, y, w, h = [int(v) for v in tracking_bbox]
self.last_detection_bbox = tracking_bbox
# Draw blue tracking box
frame_processed = self._draw_detection(frame_processed, (x, y, x+w, y+h), None, True)
tracking_count += 1
else:
# Tracking failed - optionally use last known position with reduced confidence
logger.debug("π CSRT tracking failed")
self.tracking_active = False
# Uncomment below if you want to show last known position when tracking fails
# if self.last_detection_bbox is not None:
# x, y, w, h = [int(v) for v in self.last_detection_bbox]
# frame_processed = self._draw_detection(frame_processed, (x, y, x+w, y+h), None, True)
# Write processed frame
out.write(frame_processed)
frame_count += 1
# Update progress
if progress_callback and frame_count % 10 == 0:
progress = frame_count / total_frames
progress_callback(progress, f"Processing frame {frame_count}/{total_frames}")
except Exception as e:
logger.error(f"β Error during video processing: {str(e)}")
raise
finally:
cap.release()
out.release()
logger.info(f"β
Processing completed!")
logger.info(f"π Stats: {detection_count} detections, {tracking_count} tracking frames")
return output_path
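# Standalone usage sketch (assumption: a local clip named "input.mp4"; the
# Gradio app below remains the primary entry point):
#   pipeline = DroneDetectionPipeline("best.pt")
#   output = pipeline.process_video(
#       "input.mp4",
#       progress_callback=lambda p, msg: print(f"{p:.0%} {msg}"),
#   )
#   print(f"Annotated video written to {output}")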
# Initialize pipeline
pipeline = DroneDetectionPipeline()
def process_video_gradio(input_video, confidence_threshold: float = 0.5):
    """Gradio interface function with enhanced error handling"""
    if input_video is None:
        return None, "❌ Please upload a video file"
    try:
        # Check if model is loaded properly
        if pipeline.model is None:
            return None, "❌ Model not loaded. Please check that 'best.pt' is uploaded to the space."
        # Apply the user-selected confidence threshold from the UI slider
        pipeline.confidence_threshold = confidence_threshold
# Add file size check
file_size = os.path.getsize(input_video)
max_size = 200 * 1024 * 1024 # 200MB
if file_size > max_size:
return None, f"β File too large ({file_size/1024/1024:.1f}MB). Maximum size is 200MB."
logger.info(f"πΉ Processing video: {input_video} ({file_size/1024/1024:.1f}MB)")
# Process video
output_path = pipeline.process_video(input_video)
success_msg = "β
Video processing completed successfully!"
if "yolov8n" in str(pipeline.model.model):
success_msg += "\nβ οΈ Note: Using default YOLOv8n model. Upload 'best.pt' for optimal drone detection."
return output_path, success_msg
except Exception as e:
error_msg = f"β Error processing video: {str(e)}"
logger.error(error_msg)
logger.error("Full traceback:", exc_info=True)
return None, error_msg
# Gradio Interface
def create_interface():
"""Create professional Gradio interface"""
title = "π Professional Drone Detection System"
description = """
**Advanced Drone Detection and Tracking Pipeline**
This system uses state-of-the-art YOLO object detection combined with CSRT tracking for robust drone detection in video footage.
**Features:**
- Real-time drone detection with confidence scoring
- Seamless tracking when detections are lost
- Professional-grade output for research and analysis
"""
article = """
### Technical Details
- **Detection Model**: Custom trained YOLO model optimized for drone detection
- **Tracking Algorithm**: CSRT (Channel and Spatial Reliability Tracker)
- **Output Format**: MP4 video with annotated detections and tracking
### Usage Instructions
    1. Upload your video file (supported formats: MP4, AVI, MOV, MKV)
2. Adjust confidence threshold if needed (0.1-0.9)
3. Click "Process Video" and wait for completion
4. Download the processed video with drone annotations
**Note**: Processing time depends on video length and resolution.
"""
with gr.Blocks(theme=gr.themes.Soft(), title=title) as interface:
gr.Markdown(f"# {title}")
gr.Markdown(description)
with gr.Row():
with gr.Column():
input_video = gr.Video(
label="πΉ Upload Video",
format="mp4"
)
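                # Confidence threshold control referenced in the usage
                # instructions; an addition not in the original UI, wired into
                # process_video_gradio via the click handler below.
                confidence_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.5,
                    step=0.05,
                    label="🎯 Confidence Threshold"
                )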
                process_btn = gr.Button(
                    "🚀 Process Video",
                    variant="primary",
                    size="lg"
                )
sample_videos = []
sample_files = ["sample_drone_1.mp4", "sample_drone_2.mp4", "sample_drone_3.mp4"]
for sample_file in sample_files:
if os.path.exists(sample_file):
sample_videos.append(sample_file)
if sample_videos:
gr.Markdown("**π Or try these test videos (Not a part of training process):**")
def load_sample_video(video_path):
return video_path
with gr.Row():
for i, sample_video in enumerate(sample_videos, 1):
                            sample_btn = gr.Button(f"📹 Sample {i}", size="sm")
sample_btn.click(
fn=lambda x=sample_video: x,
outputs=input_video
)
with gr.Column():
output_video = gr.Video(
label="π½οΈ Processed Video",
format="mp4"
)
status_text = gr.Textbox(
label="π Processing Status",
interactive=False
)
process_btn.click(
fn=process_video_gradio,
            inputs=[input_video, confidence_slider],
outputs=[output_video, status_text],
show_progress=True
)
gr.Markdown(article)
# Example section
gr.Markdown("### π System Requirements & Setup")
gr.Markdown("""
- **Input video formats**: MP4, AVI, MOV, MKV
- **Maximum file size**: 200MB
- **Recommended resolution**: Up to 1920x1080
- **Processing time**: ~1-2 minutes per minute of video
        **🚀 Setup Instructions:**
1. Upload your custom trained `best.pt` model to this space for optimal drone detection
2. If no custom model is found, the system will use YOLOv8n as fallback
3. For best results, use the custom drone detection model trained specifically for your use case
        **🔧 Model Upload:** Go to "Files and versions" tab → Upload your `best.pt` file → Restart the space
""")
# Status indicator
        model_status = "🟢 Custom Model" if os.path.exists("best.pt") else "🟡 Default YOLOv8n (Upload best.pt for better results)"
gr.Markdown(f"**Current Model Status**: {model_status}")
return interface
# Launch the application
if __name__ == "__main__":
interface = create_interface()
interface.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
show_error=True,
quiet=False
) |