smokeyScraper committed
Commit 0e5cf7b · unverified · 2 Parent(s): e3d6299 4655d1b

Merge pull request #2 from kartikbhtt7/add-audio-model

Files changed (5)
  1. .gitignore +3 -1
  2. app.py +308 -4
  3. requirements.txt +2 -1
  4. utils/helpers.py +49 -0
  5. utils/onnx_inference.py +47 -0
.gitignore CHANGED
@@ -1 +1,3 @@
- **/__pycache__
+ **/__pycache__
+ *.onnx
+ *.pth
app.py CHANGED
@@ -6,9 +6,16 @@ import tempfile
  import librosa
  import librosa.display
  import matplotlib.pyplot as plt
+ import tempfile
+ import librosa
+ import librosa.display
+ import matplotlib.pyplot as plt
  from PIL import Image
  import torch

+ # Import deforestation modules
+ from prediction_engine import load_onnx_model
+
  # Import deforestation modules
  from prediction_engine import load_onnx_model
  from utils.helpers import calculate_deforestation_metrics, create_overlay
@@ -17,6 +24,9 @@ from utils.helpers import calculate_deforestation_metrics, create_overlay
  from utils.audio_processing import preprocess_audio
  from utils.audio_model import load_audio_model, predict_audio, class_names

+ # Import YOLO detection modules
+ from utils.onnx_inference import YOLOv11
+
  # Ensure torch classes path is initialized to avoid warnings
  torch.classes.__path__ = []
@@ -28,15 +38,19 @@ st.set_page_config(
      initial_sidebar_state="expanded"
  )

+
  # Constants
  DEFOREST_MODEL_INPUT_SIZE = 256
  AUDIO_MODEL_PATH = "models/best_model.pth"
+ YOLO_MODEL_PATH = "models/best_model.onnx"

  # Initialize session state for navigation
  if 'current_service' not in st.session_state:
      st.session_state.current_service = 'deforestation'
  if 'audio_input_method' not in st.session_state:
      st.session_state.audio_input_method = 'upload'
+ if 'detection_input_method' not in st.session_state:
+     st.session_state.detection_input_method = 'image'

  # Sidebar for navigation
  with st.sidebar:
@@ -45,9 +59,15 @@ with st.sidebar:

      selected_service = st.radio(
          "Select Service:",
-         ["Deforestation Detection", "Forest Audio Surveillance"]
+         ["Deforestation Detection", "Forest Audio Surveillance", "Object Detection"]
      )
-     st.session_state.current_service = 'deforestation' if selected_service == "Deforestation Detection" else 'audio'
+
+     if selected_service == "Deforestation Detection":
+         st.session_state.current_service = 'deforestation'
+     elif selected_service == "Forest Audio Surveillance":
+         st.session_state.current_service = 'audio'
+     else:
+         st.session_state.current_service = 'detection'

      st.markdown("---")
@@ -60,7 +80,7 @@ with st.sidebar:
              Upload satellite or aerial images to detect areas of deforestation.
              """
          )
-     else:
+     elif st.session_state.current_service == 'audio':
          st.info(
              """
              **Forest Audio Surveillance**
@@ -92,6 +112,44 @@ with st.sidebar:
          st.markdown("🔨 **Tool Sounds:** " + ", ".join([s.capitalize() for s in tool_sounds]))
          st.markdown("🚗 **Vehicle Sounds:** " + ", ".join([s.capitalize() for s in vehicle_sounds]))
          st.markdown("💥 **Other Sounds:** " + ", ".join([s.capitalize() for s in other_sounds]))
+     else:  # Object Detection
+         st.info(
+             """
+             **Object Detection**
+
+             Detect trespassers, vehicles, fires, and other objects in forest surveillance footage.
+             """
+         )
+
+         # Detection service specific controls
+         st.subheader("Detection Configuration")
+         detection_input_method = st.radio(
+             "Select Input Method:",
+             ("Image", "Video", "Camera"),
+             index=0 if st.session_state.detection_input_method == 'image' else
+             (1 if st.session_state.detection_input_method == 'video' else 2)
+         )
+
+         if detection_input_method == "Image":
+             st.session_state.detection_input_method = 'image'
+         elif detection_input_method == "Video":
+             st.session_state.detection_input_method = 'video'
+         else:
+             st.session_state.detection_input_method = 'camera'
+
+         # Detection threshold controls
+         st.subheader("Detection Settings")
+         confidence = st.slider("Confidence Threshold", 0.0, 1.0, 0.5)
+         iou_thres = st.slider("IoU Threshold", 0.0, 1.0, 0.5)
+
+         # Detection class information
+         st.markdown("**Detection Classes:**")
+         st.markdown("🚴 **Bike/Bicycle**")
+         st.markdown("🚚 **Bus/Truck**")
+         st.markdown("🚗 **Car**")
+         st.markdown("🔥 **Fire**")
+         st.markdown("👤 **Human**")
+         st.markdown("💨 **Smoke**")

  # Load deforestation model
  @st.cache_resource
@@ -104,6 +162,10 @@ def load_cached_deforestation_model():
  def load_cached_audio_model():
      return load_audio_model(AUDIO_MODEL_PATH)

+ @st.cache_resource
+ def load_cached_yolo_model():
+     return YOLOv11(YOLO_MODEL_PATH)
+
  # Process image for deforestation detection
  def process_image(model, image):
      """Process a single image and return results"""
@@ -379,13 +441,255 @@ def show_audio_classification():
      else:
          st.write("Waiting for recording...")

+ # Object Detection UI
+ def show_object_detection():
+     # App title and description
+     st.title("🔍 Forest Object Detection")
+     st.markdown(
+         """
+         Detect trespassers, vehicles, fires, and other objects in forest surveillance footage.
+         Choose an input method to begin detection.
+         """
+     )
+
+     # Model info
+     st.info("⚙️ Object detection model optimized with ONNX runtime for faster inference")
+
+     # Load model
+     try:
+         model = load_cached_yolo_model()
+         # Update model confidence and IoU thresholds from sidebar
+         confidence = st.session_state.get('confidence', 0.5)
+         iou_thres = st.session_state.get('iou_thres', 0.5)
+         model.conf_thres = confidence
+         model.iou_thres = iou_thres
+     except Exception as e:
+         st.error(f"Error loading model: {e}")
+         st.info(
+             "Make sure you have the YOLO ONNX model file available at models/best_model.onnx"
+         )
+         return
+
+     # Input method based selection
+     if st.session_state.detection_input_method == 'image':
+         # Image upload
+         img_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
+         if img_file is not None:
+             # Load image
+             file_bytes = np.asarray(bytearray(img_file.read()), dtype=np.uint8)
+             image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
+             if image is not None:
+                 # Display original image
+                 st.subheader("Original Image")
+                 st.image(
+                     cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
+                     caption="Uploaded Image",
+                     use_container_width=True,
+                 )
+
+                 # Process with detection model
+                 with st.spinner("Processing image..."):
+                     try:
+                         detections = model.detect(image)
+                         result_image = model.draw_detections(image.copy(), detections)
+
+                         # Display results
+                         st.subheader("Detection Results")
+                         st.image(
+                             cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB),
+                             caption="Detected Objects",
+                             use_container_width=True,
+                         )
+
+                         # Display detection statistics
+                         st.subheader("Detection Statistics")
+
+                         # Count detections by class
+                         class_counts = {}
+                         for det in detections:
+                             class_name = det['class']
+                             if class_name in class_counts:
+                                 class_counts[class_name] += 1
+                             else:
+                                 class_counts[class_name] = 1
+
+                         # Display counts with emojis
+                         cols = st.columns(3)
+                         col_idx = 0
+
+                         for class_name, count in class_counts.items():
+                             emoji = "👤" if class_name == "human" else (
+                                 "🔥" if class_name == "fire" else (
+                                     "💨" if class_name == "smoke" else (
+                                         "🚗" if class_name == "car" else (
+                                             "🚴" if class_name == "bike-bicycle" else "🚚"))))
+
+                             with cols[col_idx % 3]:
+                                 st.metric(f"{emoji} {class_name.capitalize()}", count)
+                             col_idx += 1
+
+                         # Check for priority threats
+                         if "fire" in class_counts or "smoke" in class_counts:
+                             st.error("🚨 **ALERT: Fire Detected!** Potential forest fire detected. Immediate action required.")
+
+                         if "human" in class_counts or "car" in class_counts or "bike-bicycle" in class_counts or "bus-truck" in class_counts:
+                             st.warning("⚠️ **Trespassers Detected!** Unauthorized entry detected in monitored area.")
+
+                     except Exception as e:
+                         st.error(f"Error during detection: {e}")
+                         st.exception(e)
+
+     elif st.session_state.detection_input_method == 'video':
+         # Video upload
+         video_file = st.file_uploader("Upload Video", type=["mp4", "avi", "mov"])
+         if video_file is not None:
+             # Save uploaded video to temp file
+             with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tfile:
+                 tfile.write(video_file.read())
+                 temp_video_path = tfile.name
+
+             # Display video upload success
+             st.success("Video uploaded successfully!")
+
+             # Process video button
+             if st.button("Process Video"):
+                 with st.spinner("Processing video... This may take a while."):
+                     try:
+                         # Open video file
+                         cap = cv2.VideoCapture(temp_video_path)
+
+                         # Create video writer for output
+                         output_path = "output_video.mp4"
+                         fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+                         fps = cap.get(cv2.CAP_PROP_FPS)
+                         width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+                         height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                         out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+                         # Create placeholder for video frames
+                         video_placeholder = st.empty()
+                         status_text = st.empty()
+
+                         # Process frames
+                         frame_count = 0
+                         total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+                         while cap.isOpened():
+                             ret, frame = cap.read()
+                             if not ret:
+                                 break
+
+                             # Process every 5th frame for speed
+                             if frame_count % 5 == 0:
+                                 detections = model.detect(frame)
+                                 result_frame = model.draw_detections(frame.copy(), detections)
+
+                                 # Update preview
+                                 if frame_count % 15 == 0:  # Update display less frequently
+                                     video_placeholder.image(
+                                         cv2.cvtColor(result_frame, cv2.COLOR_BGR2RGB),
+                                         caption="Processing Video",
+                                         use_container_width=True
+                                     )
+                                     progress = min(100, int((frame_count / total_frames) * 100))
+                                     status_text.text(f"Processing: {progress}% complete")
+                             else:
+                                 result_frame = frame  # Skip detection on some frames
+
+                             # Write frame to output video
+                             out.write(result_frame)
+                             frame_count += 1
+
+                         # Release resources
+                         cap.release()
+                         out.release()
+
+                         # Display completion message
+                         st.success("Video processing complete!")
+
+                         # Provide download button for processed video
+                         with open(output_path, "rb") as file:
+                             st.download_button(
+                                 label="Download Processed Video",
+                                 data=file,
+                                 file_name="forest_surveillance_results.mp4",
+                                 mime="video/mp4"
+                             )
+
+                     except Exception as e:
+                         st.error(f"Error processing video: {e}")
+                         st.exception(e)
+                     finally:
+                         # Clean up temp file
+                         try:
+                             os.unlink(temp_video_path)
+                         except:
+                             pass
+
+     else:  # Camera mode
+         # Live camera feed
+         st.subheader("Live Camera Detection")
+         st.info("Use your webcam to detect objects in real-time")
+
+         cam = st.camera_input("Camera Feed")
+
+         if cam:
+             # Process camera input
+             with st.spinner("Processing image..."):
+                 try:
+                     # Convert image
+                     image = cv2.imdecode(np.frombuffer(cam.getvalue(), np.uint8), cv2.IMREAD_COLOR)
+
+                     # Run detection
+                     detections = model.detect(image)
+                     result_image = model.draw_detections(image.copy(), detections)
+
+                     # Display results
+                     st.image(
+                         cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB),
+                         caption="Detection Results",
+                         use_container_width=True
+                     )
+
+                     # Show detection summary
+                     if detections:
+                         # Count detections by class
+                         class_counts = {}
+                         for det in detections:
+                             class_name = det['class']
+                             if class_name in class_counts:
+                                 class_counts[class_name] += 1
+                             else:
+                                 class_counts[class_name] = 1
+
+                         # Display as metrics
+                         st.subheader("Detection Summary")
+                         cols = st.columns(3)
+                         for i, (class_name, count) in enumerate(class_counts.items()):
+                             with cols[i % 3]:
+                                 st.metric(class_name.capitalize(), count)
+
+                         # Check for priority threats
+                         if "fire" in class_counts or "smoke" in class_counts:
+                             st.error("🚨 **ALERT: Fire Detected!** Potential forest fire detected.")
+
+                         if "human" in class_counts:
+                             st.warning("⚠️ **Trespasser Detected!** Human presence detected.")
+                     else:
+                         st.info("No objects detected in frame")
+
+                 except Exception as e:
+                     st.error(f"Error processing camera feed: {e}")
+
  # Main function
  def main():
      # Check which service is selected and render appropriate UI
      if st.session_state.current_service == 'deforestation':
          show_deforestation_detection()
-     else:
+     elif st.session_state.current_service == 'audio':
          show_audio_classification()
+     else:  # 'detection'
+         show_object_detection()

  # Footer
  st.markdown("---")
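A note on the threshold wiring: show_object_detection() reads the sliders back through st.session_state.get('confidence', 0.5) and st.session_state.get('iou_thres', 0.5), so those session-state keys need to be populated by the sidebar widgets. A minimal sketch of one way to do that is below; the key names are assumed from the get() calls and are not part of this diff.

    # Hypothetical sketch: give the sliders explicit keys so their values land
    # in st.session_state under the names read in show_object_detection().
    confidence = st.slider("Confidence Threshold", 0.0, 1.0, 0.5, key='confidence')
    iou_thres = st.slider("IoU Threshold", 0.0, 1.0, 0.5, key='iou_thres')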
requirements.txt CHANGED
@@ -13,4 +13,5 @@ onnxruntime-gpu
  onnx
  librosa
  soundfile
- pydub
+ pydub
+ supervision
utils/helpers.py CHANGED
@@ -71,3 +71,52 @@ def create_overlay(original_image, mask, threshold=0.5, alpha=0.5):
      overlay = cv2.addWeighted(original_image, 1 - alpha, colored_mask, alpha, 0)

      return overlay
+
+
+ CLASS_NAMES = ['bike-bicycle', 'bus-truck', 'car', 'fire', 'human', 'smoke']
+ COLORS = np.random.uniform(0, 255, size=(len(CLASS_NAMES), 3))
+
+ def preprocess(image, img_size=640):
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     image = cv2.resize(image, (img_size, img_size))
+     image = image.transpose((2, 0, 1))  # HWC to CHW
+     image = np.ascontiguousarray(image, dtype=np.float32) / 255.0
+     return image[np.newaxis, ...]
+
+ def postprocess(outputs, conf_thresh=0.5, iou_thresh=0.5):
+     outputs = outputs[0].transpose()
+     boxes, scores, class_ids = [], [], []
+
+     for row in outputs:
+         cls_scores = row[4:4+len(CLASS_NAMES)]
+         class_id = np.argmax(cls_scores)
+         max_score = cls_scores[class_id]
+
+         if max_score >= conf_thresh:
+             cx, cy, w, h = row[:4]
+             x = (cx - w/2).item()  # Convert to Python float
+             y = (cy - h/2).item()
+             width = w.item()
+             height = h.item()
+             boxes.append([x, y, width, height])
+             scores.append(float(max_score))
+             class_ids.append(int(class_id))
+
+     if len(boxes) > 0:
+         # Convert to list of lists with native Python floats
+         boxes = [[float(x) for x in box] for box in boxes]
+         scores = [float(score) for score in scores]
+
+         indices = cv2.dnn.NMSBoxes(
+             bboxes=boxes,
+             scores=scores,
+             score_threshold=conf_thresh,
+             nms_threshold=iou_thresh
+         )
+
+         if len(indices) > 0:
+             boxes = [boxes[i] for i in indices.flatten()]
+             scores = [scores[i] for i in indices.flatten()]
+             class_ids = [class_ids[i] for i in indices.flatten()]
+
+     return boxes, scores, class_ids
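As a quick sanity check of the new helpers, the sketch below runs preprocess() on a dummy frame and postprocess() on a synthetic output column. The array layout (cx, cy, w, h followed by per-class scores) mirrors what postprocess() transposes, but the values and shapes here are illustrative assumptions, not part of the commit, and it presumes the repository root is on PYTHONPATH.

    import numpy as np
    from utils.helpers import CLASS_NAMES, preprocess, postprocess

    # Dummy BGR frame -> (1, 3, 640, 640) float32 blob
    frame = np.zeros((480, 640, 3), dtype=np.uint8)
    blob = preprocess(frame)
    print(blob.shape, blob.dtype)  # (1, 3, 640, 640) float32

    # One synthetic candidate column: cx, cy, w, h then per-class scores
    # (the real ONNX output carries an extra leading batch axis).
    raw = np.zeros((4 + len(CLASS_NAMES), 1), dtype=np.float32)
    raw[:4, 0] = [320, 320, 100, 80]
    raw[4 + CLASS_NAMES.index('fire'), 0] = 0.9
    boxes, scores, class_ids = postprocess([raw], conf_thresh=0.5, iou_thresh=0.5)
    print(boxes, scores, [CLASS_NAMES[i] for i in class_ids])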
utils/onnx_inference.py ADDED
@@ -0,0 +1,47 @@
+ import cv2
+ import numpy as np
+ import onnxruntime as ort
+ from .helpers import CLASS_NAMES, COLORS, preprocess, postprocess
+
+ class YOLOv11:
+     def __init__(self, onnx_path, conf_thres=0.5, iou_thres=0.5):
+         self.session = ort.InferenceSession(onnx_path)
+         self.conf_thres = conf_thres
+         self.iou_thres = iou_thres
+         self.input_name = self.session.get_inputs()[0].name
+         self.output_name = self.session.get_outputs()[0].name
+
+         # Verify input type
+         input_type = self.session.get_inputs()[0].type
+         assert "float" in input_type, f"Model expects {input_type}"
+
+     def detect(self, image):
+         orig_h, orig_w = image.shape[:2]
+         blob = preprocess(image)
+         outputs = self.session.run([self.output_name], {self.input_name: blob})
+         boxes, scores, class_ids = postprocess(outputs, self.conf_thres, self.iou_thres)
+
+         results = []
+         for box, score, class_id in zip(boxes, scores, class_ids):
+             x, y, w, h = box
+             x1 = int(x * orig_w / 640)
+             y1 = int(y * orig_h / 640)
+             x2 = int((x + w) * orig_w / 640)
+             y2 = int((y + h) * orig_h / 640)
+
+             results.append({
+                 'class': CLASS_NAMES[class_id],
+                 'confidence': score,
+                 'box': [x1, y1, x2, y2]
+             })
+         return results
+
+     def draw_detections(self, image, detections):
+         for det in detections:
+             x1, y1, x2, y2 = det['box']
+             color = COLORS[CLASS_NAMES.index(det['class'])]
+             cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
+             label = f"{det['class']}: {det['confidence']:.2f}"
+             cv2.putText(image, label, (x1, y1 - 10),
+                         cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
+         return image
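For completeness, a usage sketch of the new wrapper outside Streamlit. The image path is a placeholder, and models/best_model.onnx is assumed to exist at the same path app.py stores in YOLO_MODEL_PATH.

    import cv2
    from utils.onnx_inference import YOLOv11

    # Assumes the exported ONNX weights exist at the path app.py points to.
    model = YOLOv11("models/best_model.onnx", conf_thres=0.5, iou_thres=0.5)
    frame = cv2.imread("sample.jpg")  # placeholder test image (BGR)
    detections = model.detect(frame)
    for det in detections:
        print(det['class'], f"{det['confidence']:.2f}", det['box'])
    annotated = model.draw_detections(frame.copy(), detections)
    cv2.imwrite("sample_annotated.jpg", annotated)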