Update app.py
Browse files
app.py
CHANGED
@@ -32,17 +32,61 @@ class DroneDetectionPipeline:
|
|
32 |
self._load_model()
|
33 |
|
34 |
def _load_model(self) -> None:
|
35 |
-
"""Load YOLO model with error handling"""
|
36 |
try:
|
37 |
if os.path.exists(self.model_path):
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
logger.info(f"β
Model loaded successfully from {self.model_path}")
|
40 |
else:
|
41 |
logger.error(f"β Model file not found: {self.model_path}")
|
42 |
-
|
|
|
|
|
|
|
|
|
43 |
except Exception as e:
|
44 |
logger.error(f"β Error loading model: {str(e)}")
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
def _initialize_tracker(self, frame: np.ndarray, bbox: Tuple[int, int, int, int]) -> bool:
|
48 |
"""Initialize CSRT tracker with given bounding box"""
|
@@ -210,22 +254,39 @@ class DroneDetectionPipeline:
|
|
210 |
pipeline = DroneDetectionPipeline()
|
211 |
|
212 |
def process_video_gradio(input_video, confidence_threshold):
|
213 |
-
"""Gradio interface function"""
|
214 |
if input_video is None:
|
215 |
return None, "β Please upload a video file"
|
216 |
|
217 |
try:
|
|
|
|
|
|
|
|
|
218 |
# Update confidence threshold
|
219 |
pipeline.confidence_threshold = confidence_threshold
|
220 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
# Process video
|
222 |
output_path = pipeline.process_video(input_video)
|
223 |
|
224 |
-
|
|
|
|
|
|
|
|
|
225 |
|
226 |
except Exception as e:
|
227 |
error_msg = f"β Error processing video: {str(e)}"
|
228 |
logger.error(error_msg)
|
|
|
229 |
return None, error_msg
|
230 |
|
231 |
# Gradio Interface
|
@@ -304,13 +365,24 @@ def create_interface():
|
|
304 |
gr.Markdown(article)
|
305 |
|
306 |
# Example section
|
307 |
-
gr.Markdown("### π System Requirements")
|
308 |
gr.Markdown("""
|
309 |
-
- Input video formats
|
310 |
-
- Maximum file size
|
311 |
-
- Recommended resolution
|
312 |
-
- Processing time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
313 |
""")
|
|
|
|
|
|
|
|
|
314 |
|
315 |
return interface
|
316 |
|
|
|
32 |
self._load_model()
|
33 |
|
34 |
def _load_model(self) -> None:
|
35 |
+
"""Load YOLO model with error handling and PyTorch compatibility"""
|
36 |
try:
|
37 |
if os.path.exists(self.model_path):
|
38 |
+
# Handle PyTorch 2.6+ security restrictions for model loading
|
39 |
+
import torch
|
40 |
+
from ultralytics.nn.tasks import DetectionModel
|
41 |
+
|
42 |
+
# Add safe globals for ultralytics models
|
43 |
+
torch.serialization.add_safe_globals([DetectionModel])
|
44 |
+
|
45 |
+
# Load model with weights_only=False for older model formats
|
46 |
+
try:
|
47 |
+
self.model = YOLO(self.model_path)
|
48 |
+
except Exception as pytorch_error:
|
49 |
+
logger.warning(f"β οΈ Standard loading failed: {str(pytorch_error)}")
|
50 |
+
logger.info("π Attempting alternative loading method...")
|
51 |
+
|
52 |
+
# Alternative loading for older PyTorch model formats
|
53 |
+
import tempfile
|
54 |
+
import shutil
|
55 |
+
|
56 |
+
# Create temporary model with legacy loading
|
57 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
58 |
+
temp_model_path = os.path.join(temp_dir, "temp_model.pt")
|
59 |
+
shutil.copy2(self.model_path, temp_model_path)
|
60 |
+
|
61 |
+
# Patch torch.load temporarily for this specific model
|
62 |
+
original_load = torch.load
|
63 |
+
def patched_load(f, **kwargs):
|
64 |
+
kwargs['weights_only'] = False
|
65 |
+
return original_load(f, **kwargs)
|
66 |
+
|
67 |
+
torch.load = patched_load
|
68 |
+
try:
|
69 |
+
self.model = YOLO(temp_model_path)
|
70 |
+
finally:
|
71 |
+
torch.load = original_load
|
72 |
+
|
73 |
logger.info(f"β
Model loaded successfully from {self.model_path}")
|
74 |
else:
|
75 |
logger.error(f"β Model file not found: {self.model_path}")
|
76 |
+
# Try to load a default YOLOv8 model if custom model is not found
|
77 |
+
logger.info("π Loading default YOLOv8n model as fallback...")
|
78 |
+
self.model = YOLO('yolov8n.pt')
|
79 |
+
logger.warning("β οΈ Using default YOLOv8n model - detection accuracy may be reduced")
|
80 |
+
|
81 |
except Exception as e:
|
82 |
logger.error(f"β Error loading model: {str(e)}")
|
83 |
+
logger.info("π Attempting to load default YOLOv8n model as final fallback...")
|
84 |
+
try:
|
85 |
+
self.model = YOLO('yolov8n.pt')
|
86 |
+
logger.warning("β οΈ Using default YOLOv8n model - please upload your custom 'best.pt' model for optimal drone detection")
|
87 |
+
except Exception as fallback_error:
|
88 |
+
logger.error(f"β Failed to load any model: {str(fallback_error)}")
|
89 |
+
raise RuntimeError("Could not load any YOLO model. Please check your model file and dependencies.")
|
90 |
|
91 |
def _initialize_tracker(self, frame: np.ndarray, bbox: Tuple[int, int, int, int]) -> bool:
|
92 |
"""Initialize CSRT tracker with given bounding box"""
|
|
|
254 |
pipeline = DroneDetectionPipeline()
|
255 |
|
256 |
def process_video_gradio(input_video, confidence_threshold):
|
257 |
+
"""Gradio interface function with enhanced error handling"""
|
258 |
if input_video is None:
|
259 |
return None, "β Please upload a video file"
|
260 |
|
261 |
try:
|
262 |
+
# Check if model is loaded properly
|
263 |
+
if pipeline.model is None:
|
264 |
+
return None, "β Model not loaded. Please check that 'best.pt' is uploaded to the space."
|
265 |
+
|
266 |
# Update confidence threshold
|
267 |
pipeline.confidence_threshold = confidence_threshold
|
268 |
|
269 |
+
# Add file size check
|
270 |
+
file_size = os.path.getsize(input_video)
|
271 |
+
max_size = 200 * 1024 * 1024 # 200MB
|
272 |
+
if file_size > max_size:
|
273 |
+
return None, f"β File too large ({file_size/1024/1024:.1f}MB). Maximum size is 200MB."
|
274 |
+
|
275 |
+
logger.info(f"πΉ Processing video: {input_video} ({file_size/1024/1024:.1f}MB)")
|
276 |
+
|
277 |
# Process video
|
278 |
output_path = pipeline.process_video(input_video)
|
279 |
|
280 |
+
success_msg = "β
Video processing completed successfully!"
|
281 |
+
if "yolov8n" in str(pipeline.model.model):
|
282 |
+
success_msg += "\nβ οΈ Note: Using default YOLOv8n model. Upload 'best.pt' for optimal drone detection."
|
283 |
+
|
284 |
+
return output_path, success_msg
|
285 |
|
286 |
except Exception as e:
|
287 |
error_msg = f"β Error processing video: {str(e)}"
|
288 |
logger.error(error_msg)
|
289 |
+
logger.error("Full traceback:", exc_info=True)
|
290 |
return None, error_msg
|
291 |
|
292 |
# Gradio Interface
|
|
|
365 |
gr.Markdown(article)
|
366 |
|
367 |
# Example section
|
368 |
+
gr.Markdown("### π System Requirements & Setup")
|
369 |
gr.Markdown("""
|
370 |
+
- **Input video formats**: MP4, AVI, MOV, MKV
|
371 |
+
- **Maximum file size**: 200MB
|
372 |
+
- **Recommended resolution**: Up to 1920x1080
|
373 |
+
- **Processing time**: ~1-2 minutes per minute of video
|
374 |
+
|
375 |
+
**π Setup Instructions:**
|
376 |
+
1. Upload your custom trained `best.pt` model to this space for optimal drone detection
|
377 |
+
2. If no custom model is found, the system will use YOLOv8n as fallback
|
378 |
+
3. For best results, use the custom drone detection model trained specifically for your use case
|
379 |
+
|
380 |
+
**π§ Model Upload:** Go to "Files and versions" tab β Upload your `best.pt` file β Restart space
|
381 |
""")
|
382 |
+
|
383 |
+
# Status indicator
|
384 |
+
model_status = "π’ Custom Model" if os.path.exists("best.pt") else "π‘ Default YOLOv8n (Upload best.pt for better results)"
|
385 |
+
gr.Markdown(f"**Current Model Status**: {model_status}")
|
386 |
|
387 |
return interface
|
388 |
|