Vector73 committed
Commit 82d82cc · Parent(s): e3d6299

Added YOLO model.

Files changed (5)
  1. .gitignore +3 -1
  2. app.py +300 -4
  3. requirements.txt +2 -1
  4. utils/helpers.py +49 -0
  5. utils/onnx_inference.py +47 -0
.gitignore CHANGED
@@ -1 +1,3 @@
- **/__pycache__
+ **/__pycache__
+ *.onnx
+ *.pth
app.py CHANGED
@@ -17,6 +17,9 @@ from utils.helpers import calculate_deforestation_metrics, create_overlay
  from utils.audio_processing import preprocess_audio
  from utils.audio_model import load_audio_model, predict_audio, class_names

+ # Import YOLO detection modules
+ from utils.onnx_inference import YOLOv11
+
  # Ensure torch classes path is initialized to avoid warnings
  torch.classes.__path__ = []

@@ -31,12 +34,15 @@ st.set_page_config(
  # Constants
  DEFOREST_MODEL_INPUT_SIZE = 256
  AUDIO_MODEL_PATH = "models/best_model.pth"
+ YOLO_MODEL_PATH = "models/best_model.onnx"

  # Initialize session state for navigation
  if 'current_service' not in st.session_state:
      st.session_state.current_service = 'deforestation'
  if 'audio_input_method' not in st.session_state:
      st.session_state.audio_input_method = 'upload'
+ if 'detection_input_method' not in st.session_state:
+     st.session_state.detection_input_method = 'image'

  # Sidebar for navigation
  with st.sidebar:
@@ -45,9 +51,15 @@ with st.sidebar:

      selected_service = st.radio(
          "Select Service:",
-         ["Deforestation Detection", "Forest Audio Surveillance"]
+         ["Deforestation Detection", "Forest Audio Surveillance", "Object Detection"]
      )
-     st.session_state.current_service = 'deforestation' if selected_service == "Deforestation Detection" else 'audio'
+
+     if selected_service == "Deforestation Detection":
+         st.session_state.current_service = 'deforestation'
+     elif selected_service == "Forest Audio Surveillance":
+         st.session_state.current_service = 'audio'
+     else:
+         st.session_state.current_service = 'detection'

      st.markdown("---")

@@ -60,7 +72,7 @@ with st.sidebar:
              Upload satellite or aerial images to detect areas of deforestation.
              """
          )
-     else:
+     elif st.session_state.current_service == 'audio':
          st.info(
              """
              **Forest Audio Surveillance**
@@ -92,6 +104,44 @@ with st.sidebar:
          st.markdown("🔨 **Tool Sounds:** " + ", ".join([s.capitalize() for s in tool_sounds]))
          st.markdown("🚗 **Vehicle Sounds:** " + ", ".join([s.capitalize() for s in vehicle_sounds]))
          st.markdown("💥 **Other Sounds:** " + ", ".join([s.capitalize() for s in other_sounds]))
+     else:  # Object Detection
+         st.info(
+             """
+             **Object Detection**
+
+             Detect trespassers, vehicles, fires, and other objects in forest surveillance footage.
+             """
+         )
+
+         # Detection service specific controls
+         st.subheader("Detection Configuration")
+         detection_input_method = st.radio(
+             "Select Input Method:",
+             ("Image", "Video", "Camera"),
+             index=0 if st.session_state.detection_input_method == 'image' else
+             (1 if st.session_state.detection_input_method == 'video' else 2)
+         )
+
+         if detection_input_method == "Image":
+             st.session_state.detection_input_method = 'image'
+         elif detection_input_method == "Video":
+             st.session_state.detection_input_method = 'video'
+         else:
+             st.session_state.detection_input_method = 'camera'
+
+         # Detection threshold controls (keyed so the detection page can read them from session state)
+         st.subheader("Detection Settings")
+         confidence = st.slider("Confidence Threshold", 0.0, 1.0, 0.5, key="confidence")
+         iou_thres = st.slider("IoU Threshold", 0.0, 1.0, 0.5, key="iou_thres")
+
+         # Detection class information
+         st.markdown("**Detection Classes:**")
+         st.markdown("🚴 **Bike/Bicycle**")
+         st.markdown("🚚 **Bus/Truck**")
+         st.markdown("🚗 **Car**")
+         st.markdown("🔥 **Fire**")
+         st.markdown("👤 **Human**")
+         st.markdown("💨 **Smoke**")

  # Load deforestation model
  @st.cache_resource
@@ -104,6 +154,10 @@ def load_cached_deforestation_model():
  def load_cached_audio_model():
      return load_audio_model(AUDIO_MODEL_PATH)

+ @st.cache_resource
+ def load_cached_yolo_model():
+     return YOLOv11(YOLO_MODEL_PATH)
+
  # Process image for deforestation detection
  def process_image(model, image):
      """Process a single image and return results"""
@@ -379,13 +433,255 @@ def show_audio_classification():
      else:
          st.write("Waiting for recording...")

+ # Object Detection UI
+ def show_object_detection():
+     # App title and description
+     st.title("🔍 Forest Object Detection")
+     st.markdown(
+         """
+         Detect trespassers, vehicles, fires, and other objects in forest surveillance footage.
+         Choose an input method to begin detection.
+         """
+     )
+
+     # Model info
+     st.info("⚙️ Object detection model optimized with ONNX runtime for faster inference")
+
+     # Load model
+     try:
+         model = load_cached_yolo_model()
+         # Update model confidence and IoU thresholds from sidebar
+         confidence = st.session_state.get('confidence', 0.5)
+         iou_thres = st.session_state.get('iou_thres', 0.5)
+         model.conf_thres = confidence
+         model.iou_thres = iou_thres
+     except Exception as e:
+         st.error(f"Error loading model: {e}")
+         st.info(
+             "Make sure you have the YOLO ONNX model file available at models/best_model.onnx"
+         )
+         return
+
+     # Input method based selection
+     if st.session_state.detection_input_method == 'image':
+         # Image upload
+         img_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
+         if img_file is not None:
+             # Load image
+             file_bytes = np.asarray(bytearray(img_file.read()), dtype=np.uint8)
+             image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
+             if image is not None:
+                 # Display original image
+                 st.subheader("Original Image")
+                 st.image(
+                     cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
+                     caption="Uploaded Image",
+                     use_container_width=True,
+                 )
+
+                 # Process with detection model
+                 with st.spinner("Processing image..."):
+                     try:
+                         detections = model.detect(image)
+                         result_image = model.draw_detections(image.copy(), detections)
+
+                         # Display results
+                         st.subheader("Detection Results")
+                         st.image(
+                             cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB),
+                             caption="Detected Objects",
+                             use_container_width=True,
+                         )
+
+                         # Display detection statistics
+                         st.subheader("Detection Statistics")
+
+                         # Count detections by class
+                         class_counts = {}
+                         for det in detections:
+                             class_name = det['class']
+                             if class_name in class_counts:
+                                 class_counts[class_name] += 1
+                             else:
+                                 class_counts[class_name] = 1
+
+                         # Display counts with emojis
+                         cols = st.columns(3)
+                         col_idx = 0
+
+                         for class_name, count in class_counts.items():
+                             emoji = "👤" if class_name == "human" else (
+                                 "🔥" if class_name == "fire" else (
+                                     "💨" if class_name == "smoke" else (
+                                         "🚗" if class_name == "car" else (
+                                             "🚴" if class_name == "bike-bicycle" else "🚚"))))
+
+                             with cols[col_idx % 3]:
+                                 st.metric(f"{emoji} {class_name.capitalize()}", count)
+                             col_idx += 1
+
+                         # Check for priority threats
+                         if "fire" in class_counts or "smoke" in class_counts:
+                             st.error("🚨 **ALERT: Fire Detected!** Potential forest fire detected. Immediate action required.")
+
+                         if "human" in class_counts or "car" in class_counts or "bike-bicycle" in class_counts or "bus-truck" in class_counts:
+                             st.warning("⚠️ **Trespassers Detected!** Unauthorized entry detected in monitored area.")
+
+                     except Exception as e:
+                         st.error(f"Error during detection: {e}")
+                         st.exception(e)
+
+     elif st.session_state.detection_input_method == 'video':
+         # Video upload
+         video_file = st.file_uploader("Upload Video", type=["mp4", "avi", "mov"])
+         if video_file is not None:
+             # Save uploaded video to temp file
+             with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tfile:
+                 tfile.write(video_file.read())
+                 temp_video_path = tfile.name
+
+             # Display video upload success
+             st.success("Video uploaded successfully!")
+
+             # Process video button
+             if st.button("Process Video"):
+                 with st.spinner("Processing video... This may take a while."):
+                     try:
+                         # Open video file
+                         cap = cv2.VideoCapture(temp_video_path)
+
+                         # Create video writer for output
+                         output_path = "output_video.mp4"
+                         fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+                         fps = cap.get(cv2.CAP_PROP_FPS)
+                         width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+                         height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                         out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+                         # Create placeholder for video frames
+                         video_placeholder = st.empty()
+                         status_text = st.empty()
+
+                         # Process frames
+                         frame_count = 0
+                         total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+                         while cap.isOpened():
+                             ret, frame = cap.read()
+                             if not ret:
+                                 break
+
+                             # Process every 5th frame for speed
+                             if frame_count % 5 == 0:
+                                 detections = model.detect(frame)
+                                 result_frame = model.draw_detections(frame.copy(), detections)
+
+                                 # Update preview
+                                 if frame_count % 15 == 0:  # Update display less frequently
+                                     video_placeholder.image(
+                                         cv2.cvtColor(result_frame, cv2.COLOR_BGR2RGB),
+                                         caption="Processing Video",
+                                         use_container_width=True
+                                     )
+                                     progress = min(100, int((frame_count / total_frames) * 100))
+                                     status_text.text(f"Processing: {progress}% complete")
+                             else:
+                                 result_frame = frame  # Skip detection on some frames
+
+                             # Write frame to output video
+                             out.write(result_frame)
+                             frame_count += 1
+
+                         # Release resources
+                         cap.release()
+                         out.release()
+
+                         # Display completion message
+                         st.success("Video processing complete!")
+
+                         # Provide download button for processed video
+                         with open(output_path, "rb") as file:
+                             st.download_button(
+                                 label="Download Processed Video",
+                                 data=file,
+                                 file_name="forest_surveillance_results.mp4",
+                                 mime="video/mp4"
+                             )
+
+                     except Exception as e:
+                         st.error(f"Error processing video: {e}")
+                         st.exception(e)
+                     finally:
+                         # Clean up temp file
+                         try:
+                             os.unlink(temp_video_path)
+                         except:
+                             pass
+
+     else:  # Camera mode
+         # Live camera feed
+         st.subheader("Live Camera Detection")
+         st.info("Use your webcam to detect objects in real-time")
+
+         cam = st.camera_input("Camera Feed")
+
+         if cam:
+             # Process camera input
+             with st.spinner("Processing image..."):
+                 try:
+                     # Convert image
+                     image = cv2.imdecode(np.frombuffer(cam.getvalue(), np.uint8), cv2.IMREAD_COLOR)
+
+                     # Run detection
+                     detections = model.detect(image)
+                     result_image = model.draw_detections(image.copy(), detections)
+
+                     # Display results
+                     st.image(
+                         cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB),
+                         caption="Detection Results",
+                         use_container_width=True
+                     )
+
+                     # Show detection summary
+                     if detections:
+                         # Count detections by class
+                         class_counts = {}
+                         for det in detections:
+                             class_name = det['class']
+                             if class_name in class_counts:
+                                 class_counts[class_name] += 1
+                             else:
+                                 class_counts[class_name] = 1
+
+                         # Display as metrics
+                         st.subheader("Detection Summary")
+                         cols = st.columns(3)
+                         for i, (class_name, count) in enumerate(class_counts.items()):
+                             with cols[i % 3]:
+                                 st.metric(class_name.capitalize(), count)
+
+                         # Check for priority threats
+                         if "fire" in class_counts or "smoke" in class_counts:
+                             st.error("🚨 **ALERT: Fire Detected!** Potential forest fire detected.")
+
+                         if "human" in class_counts:
+                             st.warning("⚠️ **Trespasser Detected!** Human presence detected.")
+                     else:
+                         st.info("No objects detected in frame")
+
+                 except Exception as e:
+                     st.error(f"Error processing camera feed: {e}")
+
  # Main function
  def main():
      # Check which service is selected and render appropriate UI
      if st.session_state.current_service == 'deforestation':
          show_deforestation_detection()
-     else:
+     elif st.session_state.current_service == 'audio':
          show_audio_classification()
+     else:  # 'detection'
+         show_object_detection()

  # Footer
  st.markdown("---")
requirements.txt CHANGED
@@ -13,4 +13,5 @@ onnxruntime-gpu
  onnx
  librosa
  soundfile
- pydub
+ pydub
+ supervision
utils/helpers.py CHANGED
@@ -71,3 +71,52 @@ def create_overlay(original_image, mask, threshold=0.5, alpha=0.5):
      overlay = cv2.addWeighted(original_image, 1 - alpha, colored_mask, alpha, 0)

      return overlay
+
+
+ CLASS_NAMES = ['bike-bicycle', 'bus-truck', 'car', 'fire', 'human', 'smoke']
+ COLORS = np.random.uniform(0, 255, size=(len(CLASS_NAMES), 3))
+
+ def preprocess(image, img_size=640):
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     image = cv2.resize(image, (img_size, img_size))
+     image = image.transpose((2, 0, 1))  # HWC to CHW
+     image = np.ascontiguousarray(image, dtype=np.float32) / 255.0
+     return image[np.newaxis, ...]
+
+ def postprocess(outputs, conf_thresh=0.5, iou_thresh=0.5):
+     outputs = outputs[0].transpose()
+     boxes, scores, class_ids = [], [], []
+
+     for row in outputs:
+         cls_scores = row[4:4+len(CLASS_NAMES)]
+         class_id = np.argmax(cls_scores)
+         max_score = cls_scores[class_id]
+
+         if max_score >= conf_thresh:
+             cx, cy, w, h = row[:4]
+             x = (cx - w/2).item()  # Convert to Python float
+             y = (cy - h/2).item()
+             width = w.item()
+             height = h.item()
+             boxes.append([x, y, width, height])
+             scores.append(float(max_score))
+             class_ids.append(int(class_id))
+
+     if len(boxes) > 0:
+         # Convert to list of lists with native Python floats
+         boxes = [[float(x) for x in box] for box in boxes]
+         scores = [float(score) for score in scores]
+
+         indices = cv2.dnn.NMSBoxes(
+             bboxes=boxes,
+             scores=scores,
+             score_threshold=conf_thresh,
+             nms_threshold=iou_thresh
+         )
+
+         if len(indices) > 0:
+             boxes = [boxes[i] for i in indices.flatten()]
+             scores = [scores[i] for i in indices.flatten()]
+             class_ids = [class_ids[i] for i in indices.flatten()]
+
+     return boxes, scores, class_ids
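
The slicing row[4:4+len(CLASS_NAMES)] in postprocess implies the exported model is expected to emit a single output of shape (1, 4 + num_classes, num_anchors): four box values (cx, cy, w, h) followed by one score per class for each candidate, as in recent Ultralytics YOLO ONNX exports. That layout is an assumption about the model file, not something stated in the commit; a minimal way to sanity-check postprocess against it, with a hand-built dummy output instead of a real model (illustrative only, run from the repository root):

import numpy as np
from utils.helpers import CLASS_NAMES, postprocess

# Fake ONNX output: 1 image, 4 box values + 6 class scores, 3 candidate anchors.
dummy = np.zeros((1, 4 + len(CLASS_NAMES), 3), dtype=np.float32)
dummy[0, :4, 0] = [320, 320, 100, 50]              # cx, cy, w, h of the first candidate
dummy[0, 4 + CLASS_NAMES.index('fire'), 0] = 0.9   # confident "fire" score

boxes, scores, class_ids = postprocess([dummy], conf_thresh=0.5, iou_thresh=0.5)
print(boxes, scores, [CLASS_NAMES[i] for i in class_ids])
# Expect one surviving box in top-left x/y plus width/height form, classed as 'fire'.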
utils/onnx_inference.py ADDED
@@ -0,0 +1,47 @@
+ import cv2
+ import numpy as np
+ import onnxruntime as ort
+ from .helpers import CLASS_NAMES, COLORS, preprocess, postprocess
+
+ class YOLOv11:
+     def __init__(self, onnx_path, conf_thres=0.5, iou_thres=0.5):
+         self.session = ort.InferenceSession(onnx_path)
+         self.conf_thres = conf_thres
+         self.iou_thres = iou_thres
+         self.input_name = self.session.get_inputs()[0].name
+         self.output_name = self.session.get_outputs()[0].name
+
+         # Verify input type
+         input_type = self.session.get_inputs()[0].type
+         assert "float" in input_type, f"Model expects {input_type}"
+
+     def detect(self, image):
+         orig_h, orig_w = image.shape[:2]
+         blob = preprocess(image)
+         outputs = self.session.run([self.output_name], {self.input_name: blob})
+         boxes, scores, class_ids = postprocess(outputs, self.conf_thres, self.iou_thres)
+
+         results = []
+         for box, score, class_id in zip(boxes, scores, class_ids):
+             x, y, w, h = box
+             x1 = int(x * orig_w / 640)
+             y1 = int(y * orig_h / 640)
+             x2 = int((x + w) * orig_w / 640)
+             y2 = int((y + h) * orig_h / 640)
+
+             results.append({
+                 'class': CLASS_NAMES[class_id],
+                 'confidence': score,
+                 'box': [x1, y1, x2, y2]
+             })
+         return results
+
+     def draw_detections(self, image, detections):
+         for det in detections:
+             x1, y1, x2, y2 = det['box']
+             color = COLORS[CLASS_NAMES.index(det['class'])]
+             cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
+             label = f"{det['class']}: {det['confidence']:.2f}"
+             cv2.putText(image, label, (x1, y1 - 10),
+                         cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
+         return image
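
Because onnx_inference.py has no Streamlit dependency, the wrapper can also be smoke-tested from a plain script. A minimal sketch, assuming the exported model sits at models/best_model.onnx and a test image sample.jpg is available (both paths are illustrative placeholders, not part of the commit):

import cv2
from utils.onnx_inference import YOLOv11

# Load the exported model; thresholds can also be changed later via
# model.conf_thres / model.iou_thres, exactly as app.py does.
model = YOLOv11("models/best_model.onnx", conf_thres=0.5, iou_thres=0.5)

image = cv2.imread("sample.jpg")      # BGR frame, any resolution
detections = model.detect(image)      # list of {'class', 'confidence', 'box'} dicts

for det in detections:
    print(det['class'], round(det['confidence'], 2), det['box'])

# Draw the boxes and save the annotated frame
annotated = model.draw_detections(image.copy(), detections)
cv2.imwrite("sample_detections.jpg", annotated)

Since detect rescales boxes from the fixed 640x640 network input back to the original width and height, the printed coordinates are in the source image's pixel space.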