File size: 19,886 Bytes
8fd5341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0053d0c
8fd5341
1871e7a
8fd5341
 
 
 
 
3015dd2
8fd5341
 
3015dd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fd5341
 
 
3015dd2
 
 
 
 
8fd5341
 
3015dd2
 
 
 
 
 
 
8fd5341
55f6172
 
 
 
 
 
 
 
8fd5341
 
 
 
 
 
 
 
 
55f6172
 
8fd5341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a66591
8fd5341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55f6172
 
 
8fd5341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55f6172
8fd5341
55f6172
6cdd732
 
 
 
 
 
 
 
 
 
55f6172
6cdd732
55f6172
6cdd732
55f6172
 
6cdd732
55f6172
6cdd732
8fd5341
6cdd732
8fd5341
55f6172
 
 
 
8fd5341
62a7547
8fd5341
55f6172
8fd5341
55f6172
62a7547
55f6172
6cdd732
 
 
 
55f6172
 
 
 
62a7547
55f6172
 
 
 
62a7547
6cdd732
55f6172
6cdd732
55f6172
 
8fd5341
 
 
 
 
55f6172
6cdd732
8fd5341
 
55f6172
 
8fd5341
55f6172
 
 
 
8fd5341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cdd732
3015dd2
8fd5341
 
 
 
3015dd2
 
 
 
 
 
 
 
 
 
 
 
8fd5341
 
 
3015dd2
 
 
 
 
8fd5341
 
 
 
3015dd2
8fd5341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec9bfca
 
 
 
 
 
 
 
e5d3a96
ec9bfca
 
 
 
 
 
 
 
 
 
 
8fd5341
 
 
 
 
 
 
 
 
 
 
 
 
55f6172
8fd5341
 
 
 
 
 
 
3015dd2
8fd5341
3015dd2
 
 
 
 
 
 
 
 
 
 
8fd5341
3015dd2
 
 
 
8fd5341
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
import gradio as gr
import cv2
import numpy as np
import tempfile
import os
from pathlib import Path
from ultralytics import YOLO
import torch
from typing import Optional, Tuple, List
import logging

# Configure logging
# The root logger is set to INFO so the informational messages emitted in this
# file are shown; `logger` is the module-level logger used throughout the file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class DroneDetectionPipeline:
    """
    Professional drone detection pipeline with YOLO + CSRT tracking
    Designed for UK International Lab deployment

    Every frame is annotated from the first source that applies, in order:

    1. a detection at or above ``confidence_threshold`` (red box, also
       (re)initializes the CSRT tracker),
    2. a weaker detection, once a confident detection has been seen in the
       current video (blue box),
    3. the CSRT tracker's own prediction (blue box).

    Attributes:
        model_path: Path of the weights file requested by the caller.
        model: Loaded YOLO model (custom weights or the yolov8n fallback).
        tracker: Current CSRT tracker instance, or None.
        tracking_active: True while the tracker holds a usable target.
        has_had_detection: True once a confident detection occurred in the
            video currently being processed.
        last_detection_bbox: Last known target box as (x, y, w, h), or None.
        confidence_threshold: Minimum confidence for a "real" detection.
    """

    # Lowest confidence requested from the model; boxes between this value
    # and ``confidence_threshold`` are only used to continue an existing track.
    LOW_CONFIDENCE_FLOOR = 0.1
    # Frame rate written to the output when the input reports no usable rate.
    DEFAULT_FPS = 30.0
    # Stock weights used when the custom model is missing or cannot be loaded.
    FALLBACK_MODEL = 'yolov8n.pt'

    def __init__(self, model_path: str = "best.pt"):
        """Initialize the detection pipeline with YOLO model and CSRT tracker"""
        self.model_path = model_path
        self.model = None
        self.tracker = None
        self.tracking_active = False
        self.has_had_detection = False
        self.last_detection_bbox = None
        self.confidence_threshold = 0.5

        # Load model
        self._load_model()

    def _load_model(self) -> None:
        """
        Load the YOLO model, falling back to stock YOLOv8n weights.

        Loading order: the custom weights at ``model_path``; the same weights
        with ``torch.load`` forced to ``weights_only=False``; and finally the
        default ``yolov8n.pt``.

        Raises:
            RuntimeError: If not even the default model can be loaded.
        """
        try:
            if os.path.exists(self.model_path):
                import torch
                from ultralytics.nn.tasks import DetectionModel

                # Allow-list the ultralytics model class for restricted
                # unpickling. The helper is looked up defensively: calling it
                # on a PyTorch build that lacks it would raise and silently
                # push a perfectly good custom model onto the fallback path.
                add_safe_globals = getattr(torch.serialization, "add_safe_globals", None)
                if add_safe_globals is not None:
                    add_safe_globals([DetectionModel])

                try:
                    self.model = YOLO(self.model_path)
                except Exception as pytorch_error:
                    logger.warning("⚠️ Standard loading failed: %s", pytorch_error)
                    logger.info("🔄 Attempting alternative loading method...")
                    self.model = self._load_with_full_unpickling(torch)

                logger.info("✅ Model loaded successfully from %s", self.model_path)
            else:
                logger.error("❌ Model file not found: %s", self.model_path)
                # Try to load a default YOLOv8 model if custom model is not found
                logger.info("🔄 Loading default YOLOv8n model as fallback...")
                self.model = YOLO(self.FALLBACK_MODEL)
                logger.warning("⚠️ Using default YOLOv8n model - detection accuracy may be reduced")

        except Exception as e:
            logger.error("❌ Error loading model: %s", e)
            logger.info("🔄 Attempting to load default YOLOv8n model as final fallback...")
            try:
                self.model = YOLO(self.FALLBACK_MODEL)
                logger.warning("⚠️ Using default YOLOv8n model - please upload your custom 'best.pt' model for optimal drone detection")
            except Exception as fallback_error:
                logger.error("❌ Failed to load any model: %s", fallback_error)
                raise RuntimeError(
                    "Could not load any YOLO model. Please check your model file and dependencies."
                ) from fallback_error

    def _load_with_full_unpickling(self, torch_module):
        """Load ``model_path`` with ``torch.load`` temporarily forced to ``weights_only=False``."""
        import shutil

        # The weights are loaded from a temporary copy, as the original
        # implementation did. NOTE(review): the reason for the copy is not
        # visible in this file - confirm whether it is still needed.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_model_path = os.path.join(temp_dir, "temp_model.pt")
            shutil.copy2(self.model_path, temp_model_path)

            original_load = torch_module.load

            def patched_load(f, *args, **kwargs):
                # Positional arguments are forwarded untouched so callers that
                # pass e.g. map_location positionally keep working.
                kwargs['weights_only'] = False
                return original_load(f, *args, **kwargs)

            torch_module.load = patched_load
            try:
                return YOLO(temp_model_path)
            finally:
                # Always restore the real loader, even if loading failed.
                torch_module.load = original_load

    def _reset_tracking_state(self):
        """Reset all tracking-related state variables"""
        self.tracker = None
        self.tracking_active = False
        self.has_had_detection = False
        self.last_detection_bbox = None
        logger.info("🔄 Tracking state reset")

    @staticmethod
    def _tracker_init_succeeded(init_result) -> bool:
        """Interpret the return value of ``tracker.init``.

        Some OpenCV tracker APIs return a bool from ``init`` while others
        return nothing (``None`` in Python) and signal failure by raising, so
        ``None`` must be treated as success rather than as a falsy failure.
        """
        return init_result is None or bool(init_result)

    @staticmethod
    def _box_area(detection) -> int:
        """Return the pixel area of an ``(x1, y1, x2, y2, ...)`` detection."""
        return (detection[2] - detection[0]) * (detection[3] - detection[1])

    @classmethod
    def _resolve_fps(cls, raw_fps) -> float:
        """Return a frame rate usable for writing, or ``DEFAULT_FPS`` if *raw_fps* is not."""
        try:
            fps = float(raw_fps)
        except (TypeError, ValueError):
            return cls.DEFAULT_FPS
        # The chained comparison is False for NaN, zero, negatives and infinity.
        if not (0 < fps < float("inf")):
            return cls.DEFAULT_FPS
        return fps

    def _initialize_tracker(self, frame: np.ndarray, bbox: Tuple[int, int, int, int]) -> bool:
        """
        Initialize a fresh CSRT tracker on *frame* with *bbox*.

        Args:
            frame: Frame the bounding box refers to.
            bbox: Target box as (x, y, width, height).

        Returns:
            True if the tracker was initialized; False otherwise, in which
            case ``tracking_active`` is cleared.
        """
        try:
            # Plain Python ints: the coordinates may arrive as numpy scalars.
            bbox = tuple(int(v) for v in bbox)
            self.tracker = cv2.TrackerCSRT_create()
            success = self._tracker_init_succeeded(self.tracker.init(frame, bbox))
            if success:
                self.tracking_active = True
                self.last_detection_bbox = bbox
                logger.info("✅ Tracker initialized successfully")
            else:
                # Never leave a stale "active" flag pointing at a tracker
                # that has no target.
                self.tracking_active = False
                logger.warning("⚠️ Tracker initialization failed")
            return success
        except Exception as e:
            self.tracking_active = False
            logger.error("❌ Error initializing tracker: %s", e)
            return False

    def _run_inference(self, frame: np.ndarray, conf: float) -> List[Tuple[int, int, int, int, float]]:
        """Run the model on *frame* and return ``(x1, y1, x2, y2, confidence)`` tuples.

        Coordinates are truncated to plain Python ints. Errors propagate to
        the caller.
        """
        detections = []
        for result in self.model(frame, verbose=False, conf=conf):
            if result.boxes is None:
                continue
            for box in result.boxes:
                x1, y1, x2, y2 = (int(v) for v in box.xyxy[0].cpu().numpy())
                confidence = float(box.conf[0].cpu().numpy())
                detections.append((x1, y1, x2, y2, confidence))
        return detections

    def _detect_drones(self, frame: np.ndarray) -> List[Tuple[int, int, int, int, float]]:
        """Run YOLO inference at ``confidence_threshold``; return [] if inference fails."""
        try:
            return self._run_inference(frame, self.confidence_threshold)
        except Exception as e:
            logger.error("❌ Error in detection: %s", e)
            return []

    def _draw_detection(self, frame: np.ndarray, bbox: Tuple[int, int, int, int],
                        confidence: Optional[float] = None, is_tracking: bool = False) -> np.ndarray:
        """
        Draw a labelled bounding box on *frame* (modified in place).

        Args:
            frame: Image to draw on.
            bbox: Box corners as (x1, y1, x2, y2).
            confidence: Detection confidence; shown only for detections.
            is_tracking: True for a tracker box, False for a detection box.

        Returns:
            The same *frame* object, for call chaining.
        """
        x1, y1, x2, y2 = (int(v) for v in bbox)

        # Choose color: Red for detection, Blue for tracking
        if is_tracking:
            color = (255, 0, 0)
            label_text = "Drone (Tracker)"
        else:
            color = (0, 0, 255)
            label_text = "Drone (Detection)" if confidence is None else f"Drone {confidence:.2f}"

        # Draw bounding box
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

        # The label normally sits on top of the box; when the box touches the
        # top edge it is moved just inside the box so it is not drawn off-frame.
        label_w, label_h = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
        label_bottom = y1 if y1 - label_h - 10 >= 0 else y1 + label_h + 10

        # Draw label background
        cv2.rectangle(frame, (x1, label_bottom - label_h - 10),
                      (x1 + label_w, label_bottom), color, -1)

        # Draw label text
        cv2.putText(frame, label_text, (x1, label_bottom - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

        return frame

    def _process_frame(self, frame: np.ndarray) -> Tuple[np.ndarray, Optional[str]]:
        """Annotate one frame; return it with ``"detection"``, ``"tracking"`` or ``None``."""
        annotated = frame.copy()

        # One low-threshold inference pass, split into confident detections
        # and weaker ones that may only continue an existing track.
        detections = self._run_inference(frame, self.LOW_CONFIDENCE_FLOOR)
        high_conf = [d for d in detections if d[4] >= self.confidence_threshold]
        low_conf = [d for d in detections
                    if d[4] < self.confidence_threshold and d[4] >= self.LOW_CONFIDENCE_FLOOR]

        # PRIORITY 1: High confidence detection (RED BOX)
        if high_conf:
            x1, y1, x2, y2, conf = max(high_conf, key=self._box_area)
            annotated = self._draw_detection(annotated, (x1, y1, x2, y2), conf, False)
            # Initialize/reinitialize tracker with this high confidence detection
            self._initialize_tracker(frame, (x1, y1, x2 - x1, y2 - y1))
            self.has_had_detection = True
            return annotated, "detection"

        # PRIORITY 2: Low confidence detection (BLUE BOX) - only if we've had a detection before
        if low_conf and self.has_had_detection:
            x1, y1, x2, y2, _ = max(low_conf, key=self._box_area)
            annotated = self._draw_detection(annotated, (x1, y1, x2, y2), None, True)
            # Re-anchor the tracker on the weak detection, but only when a
            # tracker already exists.
            if self.tracker is not None:
                self._initialize_tracker(frame, (x1, y1, x2 - x1, y2 - y1))
            return annotated, "tracking"

        # PRIORITY 3: CSRT tracking - only if we've had a detection and no current detections
        if self.has_had_detection and self.tracker is not None and self.tracking_active:
            success, tracking_bbox = self.tracker.update(frame)
            if success:
                x, y, w, h = (int(v) for v in tracking_bbox)
                self.last_detection_bbox = (x, y, w, h)
                annotated = self._draw_detection(annotated, (x, y, x + w, y + h), None, True)
                return annotated, "tracking"

            # Tracking failed: stop updating until the next detection. The
            # last known position is deliberately not drawn.
            logger.debug("🔍 CSRT tracking failed")
            self.tracking_active = False

        return annotated, None

    def process_video(self, input_video_path: str, progress_callback=None) -> str:
        """
        Process entire video with drone detection and tracking

        Args:
            input_video_path: Path to input video
            progress_callback: Optional callback for progress updates, called
                every 10 frames as ``progress_callback(fraction, message)``

        Returns:
            Path to processed output video

        Raises:
            ValueError: If the input video cannot be opened.
            RuntimeError: If the output video writer cannot be opened.
        """
        # IMPORTANT: Reset tracking state at the beginning of each video
        self._reset_tracking_state()

        cap = cv2.VideoCapture(input_video_path)
        if not cap.isOpened():
            cap.release()
            raise ValueError("❌ Could not open input video")

        # The output directory is created only after the input is known to be
        # readable, so a bad upload does not leave an empty temp dir behind.
        output_dir = tempfile.mkdtemp()
        output_path = os.path.join(output_dir, "drone_detection_output.mp4")

        out = None
        frame_count = 0
        detection_count = 0
        tracking_count = 0

        try:
            # Get video properties. The frame rate is kept fractional (e.g.
            # 29.97) and replaced by a default when the input reports none.
            fps = self._resolve_fps(cap.get(cv2.CAP_PROP_FPS))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            # Initialize video writer
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
            if not out.isOpened():
                raise RuntimeError("❌ Could not open output video writer")

            logger.info("🎬 Processing video: %s frames at %s FPS", total_frames, fps)

            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                frame_processed, outcome = self._process_frame(frame)
                if outcome == "detection":
                    detection_count += 1
                elif outcome == "tracking":
                    tracking_count += 1

                # Write processed frame
                out.write(frame_processed)
                frame_count += 1

                # Update progress
                if progress_callback and frame_count % 10 == 0:
                    # The reported frame count can be zero or too small, so
                    # guard the division and cap the fraction at 1.0.
                    progress = min(frame_count / total_frames, 1.0) if total_frames > 0 else 0.0
                    progress_callback(progress, f"Processing frame {frame_count}/{total_frames}")

        except Exception as e:
            logger.error("❌ Error during video processing: %s", e)
            raise
        finally:
            cap.release()
            if out is not None:
                out.release()

        logger.info("✅ Processing completed!")
        logger.info("📊 Stats: %s detections, %s tracking frames", detection_count, tracking_count)

        return output_path

# Initialize pipeline
# Module-level instance used by process_video_gradio; constructing it runs
# DroneDetectionPipeline.__init__, which loads the model when this module is
# imported.
pipeline = DroneDetectionPipeline()

def process_video_gradio(input_video):
    """
    Gradio callback: run the drone detection pipeline on an uploaded video.

    Args:
        input_video: Path of the uploaded video file, or None when nothing
            was uploaded.

    Returns:
        A ``(output_path, status_message)`` tuple. ``output_path`` is None
        whenever processing did not complete; the message then explains why.
        Errors are reported through the message rather than raised.
    """
    if input_video is None:
        return None, "❌ Please upload a video file"

    try:
        # Check if model is loaded properly
        if pipeline.model is None:
            return None, "❌ Model not loaded. Please check that 'best.pt' is uploaded to the space."

        # Reject oversized uploads before doing any work
        file_size = os.path.getsize(input_video)
        size_mb = file_size / 1024 / 1024
        max_size = 200 * 1024 * 1024  # 200MB
        if file_size > max_size:
            return None, f"❌ File too large ({size_mb:.1f}MB). Maximum size is 200MB."

        logger.info(f"📹 Processing video: {input_video} ({size_mb:.1f}MB)")

        # Process video
        output_path = pipeline.process_video(input_video)

        success_msg = "✅ Video processing completed successfully!"
        # The custom weights file being absent is the reliable sign that the
        # default model is in use. The original check on the model's string
        # representation is kept as a secondary signal only.
        # NOTE(review): that representation presumably lists layers rather
        # than the weights file name - confirm.
        using_default_model = (
            not os.path.exists(pipeline.model_path)
            or "yolov8n" in str(pipeline.model.model)
        )
        if using_default_model:
            success_msg += "\n⚠️ Note: Using default YOLOv8n model. Upload 'best.pt' for optimal drone detection."

        return output_path, success_msg

    except Exception as e:
        error_msg = f"❌ Error processing video: {str(e)}"
        # logger.exception records the message and the traceback in one entry.
        logger.exception(error_msg)
        return None, error_msg

# Gradio Interface
def create_interface():
    """
    Build the Gradio Blocks interface for the drone detection app.

    The layout has an upload column (video input, process button and one
    button per bundled sample video found on disk) and a result column
    (processed video and status text), followed by static help text and a
    model status line.

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).
    """

    title = "🚁 Professional Drone Detection System"
    description = """
    **Advanced Drone Detection and Tracking Pipeline**
    
    This system uses state-of-the-art YOLO object detection combined with CSRT tracking for robust drone detection in video footage.
    
    **Features:**
    - Real-time drone detection with confidence scoring
    - Seamless tracking when detections are lost
    - Professional-grade output for research and analysis
    
    """

    # The usage steps list only controls that this interface actually
    # creates; there is no confidence-threshold control.
    article = """
    ### Technical Details
    - **Detection Model**: Custom trained YOLO model optimized for drone detection
    - **Tracking Algorithm**: CSRT (Channel and Spatial Reliability Tracker)
    - **Output Format**: MP4 video with annotated detections and tracking
    
    ### Usage Instructions
    1. Upload your video file (supported formats: MP4, AVI, MOV)
    2. Click "Process Video" and wait for completion
    3. Download the processed video with drone annotations
    
    **Note**: Processing time depends on video length and resolution.
    """

    with gr.Blocks(theme=gr.themes.Soft(), title=title) as interface:
        gr.Markdown(f"# {title}")
        gr.Markdown(description)

        with gr.Row():
            with gr.Column():
                input_video = gr.Video(
                    label="📹 Upload Video",
                    format="mp4"
                )
                process_btn = gr.Button(
                    "🚀 Process Video",
                    variant="primary",
                    size="lg"
                )

                # Offer only the sample videos that are present on disk.
                sample_files = ["sample_drone_1.mp4", "sample_drone_2.mp4", "sample_drone_3.mp4"]
                sample_videos = [path for path in sample_files if os.path.exists(path)]

                if sample_videos:
                    gr.Markdown("**📂 Or try these test videos (Not a part of training process):**")

                    with gr.Row():
                        for i, sample_video in enumerate(sample_videos, 1):
                            sample_btn = gr.Button(f"📹 Sample {i}", size="sm")
                            # The default argument binds the current path;
                            # a plain closure would make every button load
                            # the last sample.
                            sample_btn.click(
                                fn=lambda x=sample_video: x,
                                outputs=input_video
                            )

            with gr.Column():
                output_video = gr.Video(
                    label="📽️ Processed Video",
                    format="mp4"
                )
                status_text = gr.Textbox(
                    label="📊 Processing Status",
                    interactive=False
                )

        process_btn.click(
            fn=process_video_gradio,
            inputs=[input_video],
            outputs=[output_video, status_text],
            show_progress=True
        )

        gr.Markdown(article)

        # Example section
        gr.Markdown("### 📋 System Requirements & Setup")
        gr.Markdown("""
        - **Input video formats**: MP4, AVI, MOV, MKV
        - **Maximum file size**: 200MB
        - **Recommended resolution**: Up to 1920x1080
        - **Processing time**: ~1-2 minutes per minute of video
        
        **📁 Setup Instructions:**
        1. Upload your custom trained `best.pt` model to this space for optimal drone detection
        2. If no custom model is found, the system will use YOLOv8n as fallback
        3. For best results, use the custom drone detection model trained specifically for your use case
        
        **🔧 Model Upload:** Go to "Files and versions" tab → Upload your `best.pt` file → Restart space
        """)

        # Status indicator: checks the path the pipeline was configured with
        # rather than a second hard-coded file name.
        model_status = "🟢 Custom Model" if os.path.exists(pipeline.model_path) else "🟡 Default YOLOv8n (Upload best.pt for better results)"
        gr.Markdown(f"**Current Model Status**: {model_status}")

    return interface

# Entry point: build the UI and serve it when run as a script
if __name__ == "__main__":
    # Server settings gathered in one place: all network interfaces, port
    # 7860, no public share link, errors shown in the UI, normal console output.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
        "quiet": False,
    }
    interface = create_interface()
    interface.launch(**launch_options)