mlbench123 committed
Commit 8fd5341 · verified · 1 Parent(s): 9602d53

Create app.py

Files changed (1): app.py (+326, -0)
app.py ADDED
@@ -0,0 +1,326 @@
import gradio as gr
import cv2
import numpy as np
import tempfile
import os
from typing import Optional, Tuple, List
from ultralytics import YOLO
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

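# Dependency note: the CSRT tracker used below requires the opencv-contrib-python
# build of OpenCV (the plain opencv-python wheel does not bundle the contrib
# tracking module); ultralytics pulls in torch as its own dependency.
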
class DroneDetectionPipeline:
    """
    Professional drone detection pipeline with YOLO + CSRT tracking.
    Designed for UK International Lab deployment.
    """

    def __init__(self, model_path: str = "best.pt"):
        """Initialize the detection pipeline with a YOLO model and CSRT tracker."""
        self.model_path = model_path
        self.model = None
        self.tracker = None
        self.tracking_active = False
        self.last_detection_bbox = None
        self.confidence_threshold = 0.5

        # Load model
        self._load_model()

    def _load_model(self) -> None:
        """Load YOLO model with error handling"""
        try:
            if os.path.exists(self.model_path):
                self.model = YOLO(self.model_path)
                logger.info(f"✅ Model loaded successfully from {self.model_path}")
            else:
                logger.error(f"❌ Model file not found: {self.model_path}")
                raise FileNotFoundError(f"Model file not found: {self.model_path}")
        except Exception as e:
            logger.error(f"❌ Error loading model: {str(e)}")
            raise

    def _initialize_tracker(self, frame: np.ndarray, bbox: Tuple[int, int, int, int]) -> bool:
        """Initialize CSRT tracker with given bounding box"""
        try:
            # CSRT ships with opencv-contrib; depending on the OpenCV version it
            # is exposed as cv2.TrackerCSRT_create or cv2.legacy.TrackerCSRT_create
            if hasattr(cv2, "TrackerCSRT_create"):
                self.tracker = cv2.TrackerCSRT_create()
            else:
                self.tracker = cv2.legacy.TrackerCSRT_create()
            # In OpenCV >= 4.5.1 the non-legacy init() returns None and signals
            # failure by raising, so success is not judged by the return value
            self.tracker.init(frame, tuple(int(v) for v in bbox))
            self.tracking_active = True
            self.last_detection_bbox = bbox
            logger.info("✅ Tracker initialized successfully")
            return True
        except Exception as e:
            logger.error(f"❌ Error initializing tracker: {str(e)}")
            return False

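    # For reference: YOLO emits (x1, y1, x2, y2) corner boxes, while the OpenCV
    # tracker above consumes (x, y, w, h). A minimal conversion sketch (the
    # helper name is illustrative; process_video below does this inline):
    #
    #   def xyxy_to_xywh(x1, y1, x2, y2):
    #       return (x1, y1, x2 - x1, y2 - y1)
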
    def _detect_drones(self, frame: np.ndarray) -> List[Tuple[int, int, int, int, float]]:
        """Run YOLO inference on frame and return detections"""
        try:
            results = self.model(frame, verbose=False, conf=self.confidence_threshold)
            detections = []

            for result in results:
                if result.boxes is not None:
                    for box in result.boxes:
                        # Extract corner coordinates and confidence
                        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
                        confidence = float(box.conf[0].cpu().numpy())
                        detections.append((x1, y1, x2, y2, confidence))

            return detections
        except Exception as e:
            logger.error(f"❌ Error in detection: {str(e)}")
            return []

    def _draw_detection(self, frame: np.ndarray, bbox: Tuple[int, int, int, int],
                        confidence: Optional[float] = None, is_tracking: bool = False) -> np.ndarray:
        """Draw bounding box and label on frame"""
        x1, y1, x2, y2 = bbox

        # Choose color: red for detection, blue for tracking
        color = (255, 0, 0) if is_tracking else (0, 0, 255)
        label_text = "Drone (Track)" if is_tracking else "Drone (Det)"

        if confidence is not None and not is_tracking:
            label_text = f"Drone {confidence:.2f}"

        # Draw bounding box
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

        # Draw label background
        label_size = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
        cv2.rectangle(frame, (x1, y1 - label_size[1] - 10),
                      (x1 + label_size[0], y1), color, -1)

        # Draw label text
        cv2.putText(frame, label_text, (x1, y1 - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

        return frame

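    # Reminder: OpenCV color tuples are BGR, so (0, 0, 255) above renders red
    # and (255, 0, 0) renders blue; easy to invert when mixing with RGB libraries.
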
    def process_video(self, input_video_path: str, progress_callback=None) -> str:
        """
        Process entire video with drone detection and tracking

        Args:
            input_video_path: Path to input video
            progress_callback: Optional callback for progress updates

        Returns:
            Path to processed output video
        """
        # Create temporary output file
        output_dir = tempfile.mkdtemp()
        output_path = os.path.join(output_dir, "drone_detection_output.mp4")

        cap = cv2.VideoCapture(input_video_path)

        if not cap.isOpened():
            raise ValueError("❌ Could not open input video")

        # Get video properties (fall back to 30 FPS if the container reports none,
        # since VideoWriter rejects a zero frame rate)
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Initialize video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        frame_count = 0
        detection_count = 0
        tracking_count = 0

        logger.info(f"🎬 Processing video: {total_frames} frames at {fps} FPS")

        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                frame_processed = frame.copy()
                current_detections = self._detect_drones(frame)

                if current_detections:
                    # We have detections - use the largest one by area
                    largest_detection = max(current_detections,
                                            key=lambda d: (d[2] - d[0]) * (d[3] - d[1]))

                    x1, y1, x2, y2, conf = largest_detection
                    detection_bbox = (x1, y1, x2 - x1, y2 - y1)  # Convert to (x, y, w, h)

                    # Draw detection
                    frame_processed = self._draw_detection(frame_processed,
                                                           (x1, y1, x2, y2), conf, False)

                    # Reinitialize tracker with the new detection
                    self._initialize_tracker(frame, detection_bbox)
                    detection_count += 1

                elif self.tracking_active and self.tracker is not None:
                    # No detections but the tracker is active
                    success, tracking_bbox = self.tracker.update(frame)

                    if success:
                        x, y, w, h = [int(v) for v in tracking_bbox]
                        self.last_detection_bbox = tracking_bbox

                        # Draw tracking box
                        frame_processed = self._draw_detection(frame_processed,
                                                               (x, y, x + w, y + h), None, True)
                        tracking_count += 1
                    else:
                        # Tracking failed: draw the last known position once, then deactivate
                        if self.last_detection_bbox is not None:
                            x, y, w, h = [int(v) for v in self.last_detection_bbox]
                            frame_processed = self._draw_detection(frame_processed,
                                                                   (x, y, x + w, y + h), None, True)
                        self.tracking_active = False

                # Write processed frame
                out.write(frame_processed)
                frame_count += 1

                # Update progress
                if progress_callback and frame_count % 10 == 0:
                    progress = frame_count / total_frames
                    progress_callback(progress, f"Processing frame {frame_count}/{total_frames}")

        except Exception as e:
            logger.error(f"❌ Error during video processing: {str(e)}")
            raise
        finally:
            cap.release()
            out.release()

        logger.info("✅ Processing completed!")
        logger.info(f"📊 Stats: {detection_count} detections, {tracking_count} tracking frames")

        return output_path

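# Example of driving the pipeline outside Gradio (hypothetical file paths):
#
#   pipeline = DroneDetectionPipeline(model_path="best.pt")
#   out_path = pipeline.process_video(
#       "input.mp4",
#       progress_callback=lambda p, msg: print(f"{p:.0%} {msg}"),
#   )
#   print(f"Annotated video written to {out_path}")
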
# Initialize pipeline
pipeline = DroneDetectionPipeline()

def process_video_gradio(input_video, confidence_threshold):
    """Gradio interface function"""
    if input_video is None:
        return None, "❌ Please upload a video file"

    try:
        # Update confidence threshold
        pipeline.confidence_threshold = confidence_threshold

        # Process video
        output_path = pipeline.process_video(input_video)

        return output_path, "✅ Video processing completed successfully!"

    except Exception as e:
        error_msg = f"❌ Error processing video: {str(e)}"
        logger.error(error_msg)
        return None, error_msg

# Gradio Interface
def create_interface():
    """Create professional Gradio interface"""

    title = "🚁 Professional Drone Detection System"
    description = """
    **Advanced Drone Detection and Tracking Pipeline**

    This system uses state-of-the-art YOLO object detection combined with CSRT tracking for robust drone detection in video footage.

    **Features:**
    - Real-time drone detection with confidence scoring
    - Seamless tracking when detections are lost
    - Professional-grade output for research and analysis

    **Developed for UK International Lab**
    """

    article = """
    ### Technical Details
    - **Detection Model**: Custom-trained YOLO model optimized for drone detection
    - **Tracking Algorithm**: CSRT (Channel and Spatial Reliability Tracker)
    - **Output Format**: MP4 video with annotated detections and tracking

    ### Usage Instructions
    1. Upload your video file (supported formats: MP4, AVI, MOV)
    2. Adjust the confidence threshold if needed (0.1-0.9)
    3. Click "Process Video" and wait for completion
    4. Download the processed video with drone annotations

    **Note**: Processing time depends on video length and resolution.
    """

    with gr.Blocks(theme=gr.themes.Soft(), title=title) as interface:
        gr.Markdown(f"# {title}")
        gr.Markdown(description)

        with gr.Row():
            with gr.Column():
                input_video = gr.Video(
                    label="📹 Upload Video",
                    format="mp4"
                )
                confidence_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.5,
                    step=0.1,
                    label="🎯 Detection Confidence Threshold"
                )
                process_btn = gr.Button(
                    "🚀 Process Video",
                    variant="primary",
                    size="lg"
                )

            with gr.Column():
                output_video = gr.Video(
                    label="📽️ Processed Video",
                    format="mp4"
                )
                status_text = gr.Textbox(
                    label="📊 Processing Status",
                    interactive=False
                )

        process_btn.click(
            fn=process_video_gradio,
            inputs=[input_video, confidence_slider],
            outputs=[output_video, status_text],
            show_progress=True
        )

        gr.Markdown(article)

        # System requirements section
        gr.Markdown("### 📋 System Requirements")
        gr.Markdown("""
        - Input video formats: MP4, AVI, MOV, MKV
        - Maximum file size: 200MB
        - Recommended resolution: up to 1920x1080
        - Processing time: ~1-2 minutes per minute of video
        """)

    return interface

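# Note: processing can take minutes per clip; on shared hosts, enabling
# Gradio's request queue (interface.queue() before launch) is a common way
# to avoid request timeouts.
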
# Launch the application
if __name__ == "__main__":
    interface = create_interface()
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        quiet=False
    )