Vector73 committed · 82d82cc
Parent(s): e3d6299

Added YOLO model.

Files changed:
- .gitignore +3 -1
- app.py +300 -4
- requirements.txt +2 -1
- utils/helpers.py +49 -0
- utils/onnx_inference.py +47 -0
.gitignore
CHANGED
@@ -1 +1,3 @@
-**/__pycache__
+**/__pycache__
+*.onnx
+*.pth
app.py
CHANGED
@@ -17,6 +17,9 @@ from utils.helpers import calculate_deforestation_metrics, create_overlay
 from utils.audio_processing import preprocess_audio
 from utils.audio_model import load_audio_model, predict_audio, class_names
 
+# Import YOLO detection modules
+from utils.onnx_inference import YOLOv11
+
 # Ensure torch classes path is initialized to avoid warnings
 torch.classes.__path__ = []
 
@@ -31,12 +34,15 @@ st.set_page_config(
 # Constants
 DEFOREST_MODEL_INPUT_SIZE = 256
 AUDIO_MODEL_PATH = "models/best_model.pth"
+YOLO_MODEL_PATH = "models/best_model.onnx"
 
 # Initialize session state for navigation
 if 'current_service' not in st.session_state:
     st.session_state.current_service = 'deforestation'
 if 'audio_input_method' not in st.session_state:
     st.session_state.audio_input_method = 'upload'
+if 'detection_input_method' not in st.session_state:
+    st.session_state.detection_input_method = 'image'
 
 # Sidebar for navigation
 with st.sidebar:
@@ -45,9 +51,15 @@ with st.sidebar:
 
     selected_service = st.radio(
         "Select Service:",
-        ["Deforestation Detection", "Forest Audio Surveillance"]
+        ["Deforestation Detection", "Forest Audio Surveillance", "Object Detection"]
     )
+
+    if selected_service == "Deforestation Detection":
+        st.session_state.current_service = 'deforestation'
+    elif selected_service == "Forest Audio Surveillance":
+        st.session_state.current_service = 'audio'
+    else:
+        st.session_state.current_service = 'detection'
 
     st.markdown("---")
 
@@ -60,7 +72,7 @@ with st.sidebar:
             Upload satellite or aerial images to detect areas of deforestation.
             """
         )
+    elif st.session_state.current_service == 'audio':
         st.info(
             """
             **Forest Audio Surveillance**
@@ -92,6 +104,44 @@ with st.sidebar:
         st.markdown("🔨 **Tool Sounds:** " + ", ".join([s.capitalize() for s in tool_sounds]))
         st.markdown("🚗 **Vehicle Sounds:** " + ", ".join([s.capitalize() for s in vehicle_sounds]))
         st.markdown("💥 **Other Sounds:** " + ", ".join([s.capitalize() for s in other_sounds]))
+    else:  # Object Detection
+        st.info(
+            """
+            **Object Detection**
+
+            Detect trespassers, vehicles, fires, and other objects in forest surveillance footage.
+            """
+        )
+
+        # Detection service specific controls
+        st.subheader("Detection Configuration")
+        detection_input_method = st.radio(
+            "Select Input Method:",
+            ("Image", "Video", "Camera"),
+            index=0 if st.session_state.detection_input_method == 'image' else
+            (1 if st.session_state.detection_input_method == 'video' else 2)
+        )
+
+        if detection_input_method == "Image":
+            st.session_state.detection_input_method = 'image'
+        elif detection_input_method == "Video":
+            st.session_state.detection_input_method = 'video'
+        else:
+            st.session_state.detection_input_method = 'camera'
+
+        # Detection threshold controls
+        st.subheader("Detection Settings")
+        confidence = st.slider("Confidence Threshold", 0.0, 1.0, 0.5)
+        iou_thres = st.slider("IoU Threshold", 0.0, 1.0, 0.5)
+
+        # Detection class information
+        st.markdown("**Detection Classes:**")
+        st.markdown("🚴 **Bike/Bicycle**")
+        st.markdown("🚚 **Bus/Truck**")
+        st.markdown("🚗 **Car**")
+        st.markdown("🔥 **Fire**")
+        st.markdown("👤 **Human**")
+        st.markdown("💨 **Smoke**")
 
 # Load deforestation model
 @st.cache_resource
@@ -104,6 +154,10 @@ def load_cached_deforestation_model():
 def load_cached_audio_model():
     return load_audio_model(AUDIO_MODEL_PATH)
 
+@st.cache_resource
+def load_cached_yolo_model():
+    return YOLOv11(YOLO_MODEL_PATH)
+
 # Process image for deforestation detection
 def process_image(model, image):
     """Process a single image and return results"""
@@ -379,13 +433,255 @@ def show_audio_classification():
     else:
         st.write("Waiting for recording...")
 
+# Object Detection UI
+def show_object_detection():
+    # App title and description
+    st.title("🔍 Forest Object Detection")
+    st.markdown(
+        """
+        Detect trespassers, vehicles, fires, and other objects in forest surveillance footage.
+        Choose an input method to begin detection.
+        """
+    )
+
+    # Model info
+    st.info("⚙️ Object detection model optimized with ONNX runtime for faster inference")
+
+    # Load model
+    try:
+        model = load_cached_yolo_model()
+        # Update model confidence and IoU thresholds from sidebar
+        confidence = st.session_state.get('confidence', 0.5)
+        iou_thres = st.session_state.get('iou_thres', 0.5)
+        model.conf_thres = confidence
+        model.iou_thres = iou_thres
+    except Exception as e:
+        st.error(f"Error loading model: {e}")
+        st.info(
+            "Make sure you have the YOLO ONNX model file available at models/best_model.onnx"
+        )
+        return
+
+    # Input method based selection
+    if st.session_state.detection_input_method == 'image':
+        # Image upload
+        img_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
+        if img_file is not None:
+            # Load image
+            file_bytes = np.asarray(bytearray(img_file.read()), dtype=np.uint8)
+            image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
+            if image is not None:
+                # Display original image
+                st.subheader("Original Image")
+                st.image(
+                    cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
+                    caption="Uploaded Image",
+                    use_container_width=True,
+                )
+
+                # Process with detection model
+                with st.spinner("Processing image..."):
+                    try:
+                        detections = model.detect(image)
+                        result_image = model.draw_detections(image.copy(), detections)
+
+                        # Display results
+                        st.subheader("Detection Results")
+                        st.image(
+                            cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB),
+                            caption="Detected Objects",
+                            use_container_width=True,
+                        )
+
+                        # Display detection statistics
+                        st.subheader("Detection Statistics")
+
+                        # Count detections by class
+                        class_counts = {}
+                        for det in detections:
+                            class_name = det['class']
+                            if class_name in class_counts:
+                                class_counts[class_name] += 1
+                            else:
+                                class_counts[class_name] = 1
+
+                        # Display counts with emojis
+                        cols = st.columns(3)
+                        col_idx = 0
+
+                        for class_name, count in class_counts.items():
+                            emoji = "👤" if class_name == "human" else (
+                                "🔥" if class_name == "fire" else (
+                                "💨" if class_name == "smoke" else (
+                                "🚗" if class_name == "car" else (
+                                "🚴" if class_name == "bike-bicycle" else "🚚"))))
+
+                            with cols[col_idx % 3]:
+                                st.metric(f"{emoji} {class_name.capitalize()}", count)
+                            col_idx += 1
+
+                        # Check for priority threats
+                        if "fire" in class_counts or "smoke" in class_counts:
+                            st.error("🚨 **ALERT: Fire Detected!** Potential forest fire detected. Immediate action required.")
+
+                        if "human" in class_counts or "car" in class_counts or "bike-bicycle" in class_counts or "bus-truck" in class_counts:
+                            st.warning("⚠️ **Trespassers Detected!** Unauthorized entry detected in monitored area.")
+
+                    except Exception as e:
+                        st.error(f"Error during detection: {e}")
+                        st.exception(e)
+
+    elif st.session_state.detection_input_method == 'video':
+        # Video upload
+        video_file = st.file_uploader("Upload Video", type=["mp4", "avi", "mov"])
+        if video_file is not None:
+            # Save uploaded video to temp file
+            with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tfile:
+                tfile.write(video_file.read())
+                temp_video_path = tfile.name
+
+            # Display video upload success
+            st.success("Video uploaded successfully!")
+
+            # Process video button
+            if st.button("Process Video"):
+                with st.spinner("Processing video... This may take a while."):
+                    try:
+                        # Open video file
+                        cap = cv2.VideoCapture(temp_video_path)
+
+                        # Create video writer for output
+                        output_path = "output_video.mp4"
+                        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+                        fps = cap.get(cv2.CAP_PROP_FPS)
+                        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+                        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+                        # Create placeholder for video frames
+                        video_placeholder = st.empty()
+                        status_text = st.empty()
+
+                        # Process frames
+                        frame_count = 0
+                        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+                        while cap.isOpened():
+                            ret, frame = cap.read()
+                            if not ret:
+                                break
+
+                            # Process every 5th frame for speed
+                            if frame_count % 5 == 0:
+                                detections = model.detect(frame)
+                                result_frame = model.draw_detections(frame.copy(), detections)
+
+                                # Update preview
+                                if frame_count % 15 == 0:  # Update display less frequently
+                                    video_placeholder.image(
+                                        cv2.cvtColor(result_frame, cv2.COLOR_BGR2RGB),
+                                        caption="Processing Video",
+                                        use_container_width=True
+                                    )
+                                    progress = min(100, int((frame_count / total_frames) * 100))
+                                    status_text.text(f"Processing: {progress}% complete")
+                            else:
+                                result_frame = frame  # Skip detection on some frames
+
+                            # Write frame to output video
+                            out.write(result_frame)
+                            frame_count += 1
+
+                        # Release resources
+                        cap.release()
+                        out.release()
+
+                        # Display completion message
+                        st.success("Video processing complete!")
+
+                        # Provide download button for processed video
+                        with open(output_path, "rb") as file:
+                            st.download_button(
+                                label="Download Processed Video",
+                                data=file,
+                                file_name="forest_surveillance_results.mp4",
+                                mime="video/mp4"
+                            )
+
+                    except Exception as e:
+                        st.error(f"Error processing video: {e}")
+                        st.exception(e)
+                    finally:
+                        # Clean up temp file
+                        try:
+                            os.unlink(temp_video_path)
+                        except:
+                            pass
+
+    else:  # Camera mode
+        # Live camera feed
+        st.subheader("Live Camera Detection")
+        st.info("Use your webcam to detect objects in real-time")
+
+        cam = st.camera_input("Camera Feed")
+
+        if cam:
+            # Process camera input
+            with st.spinner("Processing image..."):
+                try:
+                    # Convert image
+                    image = cv2.imdecode(np.frombuffer(cam.getvalue(), np.uint8), cv2.IMREAD_COLOR)
+
+                    # Run detection
+                    detections = model.detect(image)
+                    result_image = model.draw_detections(image.copy(), detections)
+
+                    # Display results
+                    st.image(
+                        cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB),
+                        caption="Detection Results",
+                        use_container_width=True
+                    )
+
+                    # Show detection summary
+                    if detections:
+                        # Count detections by class
+                        class_counts = {}
+                        for det in detections:
+                            class_name = det['class']
+                            if class_name in class_counts:
+                                class_counts[class_name] += 1
+                            else:
+                                class_counts[class_name] = 1
+
+                        # Display as metrics
+                        st.subheader("Detection Summary")
+                        cols = st.columns(3)
+                        for i, (class_name, count) in enumerate(class_counts.items()):
+                            with cols[i % 3]:
+                                st.metric(class_name.capitalize(), count)
+
+                        # Check for priority threats
+                        if "fire" in class_counts or "smoke" in class_counts:
+                            st.error("🚨 **ALERT: Fire Detected!** Potential forest fire detected.")
+
+                        if "human" in class_counts:
+                            st.warning("⚠️ **Trespasser Detected!** Human presence detected.")
+                    else:
+                        st.info("No objects detected in frame")
+
+                except Exception as e:
+                    st.error(f"Error processing camera feed: {e}")
+
 # Main function
 def main():
     # Check which service is selected and render appropriate UI
     if st.session_state.current_service == 'deforestation':
         show_deforestation_detection()
+    elif st.session_state.current_service == 'audio':
         show_audio_classification()
+    else:  # 'detection'
+        show_object_detection()
 
     # Footer
     st.markdown("---")
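Note on the threshold sliders: show_object_detection() reads the thresholds via st.session_state.get('confidence', 0.5) and st.session_state.get('iou_thres', 0.5), while the sidebar sliders above only assign to local variables. Streamlit mirrors a widget's value into st.session_state only when the widget is given a key, so one possible way to wire the two together (key names taken from the get() calls; this sketch is not part of the commit) would be:

# Hypothetical sidebar variant, assuming the slider values should reach
# show_object_detection() through st.session_state; not part of this commit.
confidence = st.slider("Confidence Threshold", 0.0, 1.0, 0.5, key="confidence")
iou_thres = st.slider("IoU Threshold", 0.0, 1.0, 0.5, key="iou_thres")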
requirements.txt
CHANGED
@@ -13,4 +13,5 @@ onnxruntime-gpu
 onnx
 librosa
 soundfile
-pydub
+pydub
+supervision
utils/helpers.py
CHANGED
@@ -71,3 +71,52 @@ def create_overlay(original_image, mask, threshold=0.5, alpha=0.5):
     overlay = cv2.addWeighted(original_image, 1 - alpha, colored_mask, alpha, 0)
 
     return overlay
+
+
+CLASS_NAMES = ['bike-bicycle', 'bus-truck', 'car', 'fire', 'human', 'smoke']
+COLORS = np.random.uniform(0, 255, size=(len(CLASS_NAMES), 3))
+
+def preprocess(image, img_size=640):
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    image = cv2.resize(image, (img_size, img_size))
+    image = image.transpose((2, 0, 1))  # HWC to CHW
+    image = np.ascontiguousarray(image, dtype=np.float32) / 255.0
+    return image[np.newaxis, ...]
+
+def postprocess(outputs, conf_thresh=0.5, iou_thresh=0.5):
+    outputs = outputs[0].transpose()
+    boxes, scores, class_ids = [], [], []
+
+    for row in outputs:
+        cls_scores = row[4:4+len(CLASS_NAMES)]
+        class_id = np.argmax(cls_scores)
+        max_score = cls_scores[class_id]
+
+        if max_score >= conf_thresh:
+            cx, cy, w, h = row[:4]
+            x = (cx - w/2).item()  # Convert to Python float
+            y = (cy - h/2).item()
+            width = w.item()
+            height = h.item()
+            boxes.append([x, y, width, height])
+            scores.append(float(max_score))
+            class_ids.append(int(class_id))
+
+    if len(boxes) > 0:
+        # Convert to list of lists with native Python floats
+        boxes = [[float(x) for x in box] for box in boxes]
+        scores = [float(score) for score in scores]
+
+        indices = cv2.dnn.NMSBoxes(
+            bboxes=boxes,
+            scores=scores,
+            score_threshold=conf_thresh,
+            nms_threshold=iou_thresh
+        )
+
+        if len(indices) > 0:
+            boxes = [boxes[i] for i in indices.flatten()]
+            scores = [scores[i] for i in indices.flatten()]
+            class_ids = [class_ids[i] for i in indices.flatten()]
+
+    return boxes, scores, class_ids
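The new postprocess() assumes the common YOLOv8/YOLO11 export layout: a single output tensor of shape (1, 4 + num_classes, num_candidates) whose columns are [cx, cy, w, h, class scores...] in 640x640 input coordinates. A minimal sanity check under that assumption (synthetic values, not taken from the repository):

# Synthetic sanity check for postprocess(); assumes the (1, 4 + num_classes, N)
# output layout described above. All values below are made up for illustration.
import numpy as np
from utils.helpers import CLASS_NAMES, postprocess

raw = np.zeros((1, 4 + len(CLASS_NAMES), 8), dtype=np.float32)
raw[0, :4, 0] = [320.0, 320.0, 100.0, 80.0]        # one centred box (cx, cy, w, h)
raw[0, 4 + CLASS_NAMES.index('fire'), 0] = 0.9     # confident "fire" score

boxes, scores, class_ids = postprocess([raw], conf_thresh=0.5, iou_thresh=0.5)
print(boxes, scores, [CLASS_NAMES[i] for i in class_ids])
# roughly: [[270.0, 280.0, 100.0, 80.0]] [0.9] ['fire']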
utils/onnx_inference.py
ADDED
@@ -0,0 +1,47 @@
+import cv2
+import numpy as np
+import onnxruntime as ort
+from .helpers import CLASS_NAMES, COLORS, preprocess, postprocess
+
+class YOLOv11:
+    def __init__(self, onnx_path, conf_thres=0.5, iou_thres=0.5):
+        self.session = ort.InferenceSession(onnx_path)
+        self.conf_thres = conf_thres
+        self.iou_thres = iou_thres
+        self.input_name = self.session.get_inputs()[0].name
+        self.output_name = self.session.get_outputs()[0].name
+
+        # Verify input type
+        input_type = self.session.get_inputs()[0].type
+        assert "float" in input_type, f"Model expects {input_type}"
+
+    def detect(self, image):
+        orig_h, orig_w = image.shape[:2]
+        blob = preprocess(image)
+        outputs = self.session.run([self.output_name], {self.input_name: blob})
+        boxes, scores, class_ids = postprocess(outputs, self.conf_thres, self.iou_thres)
+
+        results = []
+        for box, score, class_id in zip(boxes, scores, class_ids):
+            x, y, w, h = box
+            x1 = int(x * orig_w / 640)
+            y1 = int(y * orig_h / 640)
+            x2 = int((x + w) * orig_w / 640)
+            y2 = int((y + h) * orig_h / 640)
+
+            results.append({
+                'class': CLASS_NAMES[class_id],
+                'confidence': score,
+                'box': [x1, y1, x2, y2]
+            })
+        return results
+
+    def draw_detections(self, image, detections):
+        for det in detections:
+            x1, y1, x2, y2 = det['box']
+            color = COLORS[CLASS_NAMES.index(det['class'])]
+            cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
+            label = f"{det['class']}: {det['confidence']:.2f}"
+            cv2.putText(image, label, (x1, y1 - 10),
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
+        return image
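For reference, a minimal standalone usage sketch of the new wrapper outside Streamlit, assuming the exported model sits at models/best_model.onnx and that a test frame (the hypothetical sample_frame.jpg below) exists on disk:

# Standalone usage sketch for utils/onnx_inference.YOLOv11; the file paths are
# illustrative assumptions, not part of the commit.
import cv2
from utils.onnx_inference import YOLOv11

model = YOLOv11("models/best_model.onnx", conf_thres=0.4, iou_thres=0.5)
frame = cv2.imread("sample_frame.jpg")          # BGR image, as used in app.py
detections = model.detect(frame)                # [{'class', 'confidence', 'box'}, ...]
annotated = model.draw_detections(frame.copy(), detections)
cv2.imwrite("sample_frame_annotated.jpg", annotated)
for det in detections:
    print(det['class'], round(det['confidence'], 2), det['box'])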