Spaces:
Sleeping

Vector73 committed
Commit · 01e938d · 1 Parent(s): 8c38d83

Add audio model.
Browse files
- .gitignore +1 -0
- app.py +270 -13
- models/best_model.pth +3 -0
- prediction_engine.py +157 -0
- requirements.txt +4 -1
- utils/audio_model.py +76 -0
- utils/audio_processing.py +42 -0
.gitignore ADDED
@@ -0,0 +1 @@
+**/__pycache__
app.py CHANGED
@@ -2,25 +2,109 @@ import streamlit as st
 import cv2
 import numpy as np
 import os
+import tempfile
+import librosa
+import librosa.display
+import matplotlib.pyplot as plt
 from PIL import Image
 import torch
-
+
+# Import deforestation modules
+from prediction_engine import load_onnx_model
 from utils.helpers import calculate_deforestation_metrics, create_overlay
 
+# Import audio classification modules
+from utils.audio_processing import preprocess_audio
+from utils.audio_model import load_audio_model, predict_audio, class_names
+
+# Ensure torch classes path is initialized to avoid warnings
 torch.classes.__path__ = []
 
 # Set page config
-st.set_page_config(
+st.set_page_config(
+    page_title="Nature Nexus - Forest Surveillance",
+    page_icon="🌳",
+    layout="wide",
+    initial_sidebar_state="expanded"
+)
+
+# Constants
+DEFOREST_MODEL_INPUT_SIZE = 256
+AUDIO_MODEL_PATH = "models/best_model.pth"
 
-#
-
+# Initialize session state for navigation
+if 'current_service' not in st.session_state:
+    st.session_state.current_service = 'deforestation'
+if 'audio_input_method' not in st.session_state:
+    st.session_state.audio_input_method = 'upload'
+
+# Sidebar for navigation
+with st.sidebar:
+    st.title("Nature Nexus")
+    st.subheader("Forest Surveillance System")
+
+    selected_service = st.radio(
+        "Select Service:",
+        ["Deforestation Detection", "Forest Audio Surveillance"]
+    )
+    st.session_state.current_service = 'deforestation' if selected_service == "Deforestation Detection" else 'audio'
+
+    st.markdown("---")
+
+    # Service-specific sidebar content
+    if st.session_state.current_service == 'deforestation':
+        st.info(
+            """
+            **Deforestation Detection**
+
+            Upload satellite or aerial images to detect areas of deforestation.
+            """
+        )
+    else:
+        st.info(
+            """
+            **Forest Audio Surveillance**
+
+            Detect unusual human-related sounds in forested regions.
+            """
+        )
+
+        # Audio service specific controls
+        st.subheader("Audio Configuration")
+        audio_input_method = st.radio(
+            "Select Input Method:",
+            ("Upload Audio", "Record Audio"),
+            index=0 if st.session_state.audio_input_method == 'upload' else 1
+        )
+        st.session_state.audio_input_method = 'upload' if audio_input_method == "Upload Audio" else 'record'
+
+        # Audio class information
+        st.markdown("**Detection Classes:**")
+
+        # Group classes by category
+        human_sounds = ['footsteps', 'coughing', 'laughing', 'breathing',
+                        'drinking_sipping', 'snoring', 'sneezing']
+        tool_sounds = ['chainsaw', 'hand_saw']
+        vehicle_sounds = ['car_horn', 'engine', 'siren']
+        other_sounds = ['crackling_fire', 'fireworks']
+
+        st.markdown("👤 **Human Sounds:** " + ", ".join([s.capitalize() for s in human_sounds]))
+        st.markdown("🔨 **Tool Sounds:** " + ", ".join([s.capitalize() for s in tool_sounds]))
+        st.markdown("🚗 **Vehicle Sounds:** " + ", ".join([s.capitalize() for s in vehicle_sounds]))
+        st.markdown("💥 **Other Sounds:** " + ", ".join([s.capitalize() for s in other_sounds]))
 
-# Load
+# Load deforestation model
 @st.cache_resource
-def
+def load_cached_deforestation_model():
     model_path = "models/deforestation_model.onnx"
-    return load_onnx_model(model_path, input_size=
+    return load_onnx_model(model_path, input_size=DEFOREST_MODEL_INPUT_SIZE)
+
+# Load audio model
+@st.cache_resource
+def load_cached_audio_model():
+    return load_audio_model(AUDIO_MODEL_PATH)
 
+# Process image for deforestation detection
 def process_image(model, image):
     """Process a single image and return results"""
     # Save original image dimensions for display

@@ -43,24 +127,121 @@ def process_image(model, image):
 
     return binary_mask, overlay, metrics
 
-
+# Visualize audio for audio classification
+def visualize_audio(audio_path):
+    y, sr = librosa.load(audio_path, sr=16000)
+    duration = len(y) / sr
+
+    fig, ax = plt.subplots(2, 1, figsize=(10, 6))
+
+    # Waveform plot
+    librosa.display.waveshow(y, sr=sr, ax=ax[0])
+    ax[0].set_title('Audio Waveform')
+    ax[0].set_xlabel('Time (s)')
+    ax[0].set_ylabel('Amplitude')
+
+    # Spectrogram plot
+    S = librosa.feature.melspectrogram(y=y, sr=sr)
+    S_db = librosa.power_to_db(S, ref=np.max)
+    img = librosa.display.specshow(S_db, sr=sr, x_axis='time', y_axis='mel', ax=ax[1])
+    fig.colorbar(img, ax=ax[1], format='%+2.0f dB')
+    ax[1].set_title('Mel Spectrogram')
+
+    plt.tight_layout()
+    st.pyplot(fig)
+
+    return y, sr, duration
+
+# Process audio for classification
+def process_audio(audio_file):
+    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
+        tmp_file.write(audio_file.read() if hasattr(audio_file, 'read') else audio_file)
+        audio_path = tmp_file.name
+
+    try:
+        # Load audio model
+        audio_model = load_cached_audio_model()
+
+        # Visualize audio
+        with st.spinner('Analyzing audio...'):
+            y, sr, duration = visualize_audio(audio_path)
+            st.caption(f"Audio duration: {duration:.2f} seconds")
+
+        # Make prediction
+        with st.spinner('Making prediction...'):
+            class_name, confidence = predict_audio(audio_path, audio_model)
+
+        # Display results
+        st.subheader("Detection Results")
+
+        col1, col2 = st.columns(2)
+        with col1:
+            st.metric("Detected Sound", class_name.replace('_', ' ').title())
+        with col2:
+            st.metric("Confidence", f"{confidence*100:.2f}%")
+
+        # Show alerts based on class
+        human_sounds = ['footsteps', 'coughing', 'laughing', 'breathing',
+                        'drinking_sipping', 'snoring', 'sneezing']
+        tool_sounds = ['chainsaw', 'hand_saw']
+
+        if class_name in human_sounds:
+            st.warning("""
+            ⚠️ **Human Activity Detected!**
+            Potential human presence in the monitored area.
+            """)
+        elif class_name in tool_sounds:
+            st.error("""
+            🚨 **ALERT: Human Tool Detected!**
+            Potential illegal logging or activity detected. Consider immediate verification.
+            """)
+        elif class_name in ['car_horn', 'engine', 'siren']:
+            st.warning("""
+            ⚠️ **Vehicle Detected!**
+            Vehicle sounds detected in the monitored area.
+            """)
+        elif class_name == 'fireworks':
+            st.error("""
+            🚨 **ALERT: Fireworks Detected!**
+            Potential fire hazard and disturbance to wildlife. Immediate verification required.
+            """)
+        elif class_name == 'crackling_fire':
+            st.error("""
+            🚨 **ALERT: Fire Detected!**
+            Potential wildfire detected. Immediate verification required.
+            """)
+        else:
+            st.success("✅ Environmental sound detected - no immediate threat")
+
+    except Exception as e:
+        st.error(f"Error processing audio: {str(e)}")
+        st.exception(e)
+    finally:
+        # Clean up temp file
+        try:
+            os.unlink(audio_path)
+        except:
+            pass
+
+# Deforestation detection UI
+def show_deforestation_detection():
     # App title and description
     st.title("🌳 Deforestation Detection")
     st.markdown(
         """
-
-
-
+        This service detects areas of deforestation in satellite or aerial images of forests.
+        Upload an image to get started!
+        """
     )
 
     # Model info
     st.info(
-        f"⚙️ Model optimized for {
+        f"⚙️ Model optimized for {DEFOREST_MODEL_INPUT_SIZE}x{DEFOREST_MODEL_INPUT_SIZE} pixel images using ONNX runtime"
    )
 
    # Load model
    try:
-        model =
+        model = load_cached_deforestation_model()
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.info(

@@ -139,5 +320,81 @@ def main():
    except Exception as e:
        st.error(f"Error during processing: {e}")
 
+# Audio classification UI
+def show_audio_classification():
+    # App title and description
+    st.title("🎧 Forest Audio Surveillance")
+    st.markdown("""
+    Detect unusual human-related sounds in forested regions to prevent illegal activities.
+    Supported sounds: {}
+    """.format(", ".join(class_names)))
+
+    if st.session_state.audio_input_method == 'upload':
+        st.header("Upload Audio File")
+
+        sample_col, upload_col = st.columns(2)
+        with sample_col:
+            st.info("Upload a WAV, MP3 or OGG file with forest sounds")
+            st.markdown("""
+            **Tips for best results:**
+            - Use audio with minimal background noise
+            - Ensure the sound of interest is clear
+            - 2-3 second clips work best
+            """)
+
+        with upload_col:
+            audio_file = st.file_uploader(
+                "Choose an audio file",
+                type=["wav", "mp3", "ogg"],
+                help="Supported formats: WAV, MP3, OGG"
+            )
+
+        if audio_file:
+            st.success("File uploaded successfully!")
+            with st.expander("Audio Preview", expanded=True):
+                st.audio(audio_file)
+            process_audio(audio_file)
+
+    else:  # Record mode
+        st.header("Record Live Audio")
+
+        st.info("""
+        Click the microphone button below to record a sound for analysis.
+        **Note:** Please ensure your browser has permission to access your microphone.
+        When prompted, click "Allow" to enable recording.
+        """)
+
+        recorded_audio = st.audio_input(
+            label="Record a sound",
+            key="audio_recorder",
+            help="Click to record forest sounds for analysis",
+            label_visibility="visible"
+        )
+
+        if recorded_audio:
+            st.success("Audio recorded successfully!")
+            with st.expander("Recorded Audio", expanded=True):
+                st.audio(recorded_audio)
+            process_audio(recorded_audio)
+        else:
+            st.write("Waiting for recording...")
+
+# Main function
+def main():
+    # Check which service is selected and render appropriate UI
+    if st.session_state.current_service == 'deforestation':
+        show_deforestation_detection()
+    else:
+        show_audio_classification()
+
+    # Footer
+    st.markdown("---")
+    st.markdown("""
+    <div style="text-align: center; padding: 10px;">
+        <p>Nature Nexus - Forest Surveillance System | 🌳 Protect Natural Ecosystems</p>
+        <p><small>Built with Streamlit and PyTorch</small></p>
+    </div>
+    """, unsafe_allow_html=True)
+
 if __name__ == "__main__":
     main()
models/best_model.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fff4ca890869016c359ce0991e22c0df72bdaee45b4512f5252967fe44361095
+size 5148310
prediction_engine.py ADDED
@@ -0,0 +1,157 @@
+import torch
+import numpy as np
+import cv2
+import onnxruntime as ort
+from utils.preprocess import preprocess_image
+
+
+class PredictionEngine:
+    def __init__(self, model_path=None, use_onnx=True, input_size=256):
+        """
+        Initialize the prediction engine
+
+        Args:
+            model_path: Path to the model file (PyTorch or ONNX)
+            use_onnx: Whether to use ONNX runtime for inference
+            input_size: Input size for the model (default is 256)
+        """
+        self.use_onnx = use_onnx
+        self.input_size = input_size
+
+        if model_path:
+            if use_onnx:
+                self.model = self._load_onnx_model(model_path)
+            else:
+                self.model = self._load_pytorch_model(model_path)
+        else:
+            self.model = None
+
+    def _load_onnx_model(self, model_path):
+        """
+        Load an ONNX model
+
+        Args:
+            model_path: Path to the ONNX model
+
+        Returns:
+            ONNX Runtime InferenceSession
+        """
+        # Try with CUDA first, fall back to CPU if needed
+        try:
+            session = ort.InferenceSession(
+                model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
+            )
+            print("ONNX model loaded with CUDA support")
+            return session
+        except Exception as e:
+            print(f"Could not load ONNX model with CUDA, falling back to CPU: {e}")
+            session = ort.InferenceSession(
+                model_path, providers=["CPUExecutionProvider"]
+            )
+            print("ONNX model loaded with CPU support")
+            return session
+
+    def _load_pytorch_model(self, model_path):
+        """
+        Load a PyTorch model
+
+        Args:
+            model_path: Path to the PyTorch model
+
+        Returns:
+            PyTorch model
+        """
+        from utils.model import load_model
+        return load_model(model_path)
+
+    def preprocess(self, image):
+        """
+        Preprocess an image for prediction
+
+        Args:
+            image: Input image (numpy array)
+
+        Returns:
+            Processed image suitable for the model
+        """
+        # Keep the original image for reference
+        self.original_shape = image.shape[:2]
+
+        # Preprocess image
+        if self.use_onnx:
+            # For ONNX, we need to ensure the input is exactly the expected size
+            tensor = preprocess_image(image, img_size=self.input_size)
+            return tensor.numpy()
+        else:
+            # For PyTorch
+            return preprocess_image(image, img_size=self.input_size)
+
+    def predict(self, image):
+        """
+        Make a prediction on an image
+
+        Args:
+            image: Input image (numpy array)
+
+        Returns:
+            Predicted mask
+        """
+        if self.model is None:
+            raise ValueError("Model not loaded. Initialize with a valid model path.")
+
+        # Preprocess the image
+        processed_input = self.preprocess(image)
+
+        # Run inference
+        if self.use_onnx:
+            # Get input and output names
+            input_name = self.model.get_inputs()[0].name
+            output_name = self.model.get_outputs()[0].name
+
+            # Run ONNX inference
+            outputs = self.model.run([output_name], {input_name: processed_input})
+
+            # Apply sigmoid to output
+            mask = 1 / (1 + np.exp(-outputs[0].squeeze()))
+        else:
+            # PyTorch inference
+            with torch.no_grad():
+                # Move to device
+                device = next(self.model.parameters()).device
+                processed_input = processed_input.to(device)
+
+                # Forward pass
+                output = self.model(processed_input)
+                output = torch.sigmoid(output)
+
+                # Convert to numpy
+                mask = output.cpu().numpy().squeeze()
+
+        return mask
+
+
+def load_pytorch_model(model_path):
+    """
+    Load the PyTorch model for prediction
+
+    Args:
+        model_path: Path to the PyTorch model
+
+    Returns:
+        PredictionEngine instance
+    """
+    return PredictionEngine(model_path, use_onnx=False)
+
+
+def load_onnx_model(model_path, input_size=256):
+    """
+    Load the ONNX model for prediction
+
+    Args:
+        model_path: Path to the ONNX model
+        input_size: Input size for the model
+
+    Returns:
+        PredictionEngine instance
+    """
+    return PredictionEngine(model_path, use_onnx=True, input_size=input_size)
requirements.txt CHANGED
@@ -10,4 +10,7 @@ scikit-learn
 matplotlib
 onnxruntime
 onnxruntime-gpu
-onnx
+onnx
+librosa
+soundfile
+pydub
utils/audio_model.py ADDED
@@ -0,0 +1,76 @@
+import torch
+import numpy as np
+from utils.audio_processing import preprocess_audio
+
+class_names = [
+    'fireworks', 'chainsaw', 'footsteps', 'car_horn', 'crackling_fire',
+    'drinking_sipping', 'laughing', 'engine', 'breathing', 'hand_saw',
+    'coughing', 'snoring', 'sneezing', 'siren'
+]
+
+class AudioClassifier(torch.nn.Module):
+    def __init__(self, num_classes=14):
+        super().__init__()
+        self.features = torch.nn.Sequential(
+            torch.nn.Conv2d(1, 64, 3, padding=1),
+            torch.nn.BatchNorm2d(64),
+            torch.nn.ReLU(),
+            torch.nn.Conv2d(64, 64, 3, padding=1),
+            torch.nn.BatchNorm2d(64),
+            torch.nn.ReLU(),
+            torch.nn.MaxPool2d(2),
+            torch.nn.Dropout(0.2),
+
+            torch.nn.Conv2d(64, 128, 3, padding=1),
+            torch.nn.BatchNorm2d(128),
+            torch.nn.ReLU(),
+            torch.nn.Conv2d(128, 128, 3, padding=1),
+            torch.nn.BatchNorm2d(128),
+            torch.nn.ReLU(),
+            torch.nn.MaxPool2d(2),
+            torch.nn.Dropout(0.2),
+
+            torch.nn.Conv2d(128, 256, 3, padding=1),
+            torch.nn.BatchNorm2d(256),
+            torch.nn.ReLU(),
+            torch.nn.Conv2d(256, 256, 3, padding=1),
+            torch.nn.BatchNorm2d(256),
+            torch.nn.ReLU(),
+            torch.nn.MaxPool2d(2),
+            torch.nn.Dropout(0.2)
+        )
+        self.classifier = torch.nn.Sequential(
+            torch.nn.AdaptiveAvgPool2d(1),
+            torch.nn.Flatten(),
+            torch.nn.Linear(256, 256),
+            torch.nn.ReLU(),
+            torch.nn.Linear(256, 256),
+            torch.nn.ReLU(),
+            torch.nn.Linear(256, num_classes)
+        )
+
+    def forward(self, x):
+        x = self.features(x)
+        return self.classifier(x)
+
+def load_audio_model(model_path='models/audio_model.pth'):
+    model = AudioClassifier(len(class_names))
+    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+    model.eval()
+    return model
+
+def predict_audio(audio_path, model):
+    # Preprocess audio
+    spec = preprocess_audio(audio_path)
+
+    # Convert to tensor
+    input_tensor = torch.FloatTensor(spec).unsqueeze(0)  # Add batch dimension
+
+    # Predict
+    with torch.no_grad():
+        outputs = model(input_tensor)
+        probabilities = torch.nn.functional.softmax(outputs, dim=1)
+
+    # Get results
+    pred_prob, pred_index = torch.max(probabilities, 1)
+    return class_names[pred_index.item()], pred_prob.item()
utils/audio_processing.py ADDED
@@ -0,0 +1,42 @@
+import librosa
+import numpy as np
+
+class AudioConfig:
+    sr = 16000
+    duration = 3
+    hop_length = 340 * duration
+    fmin = 20
+    fmax = sr // 2
+    n_mels = 128
+    n_fft = 128 * 20
+    samples = sr * duration
+
+def preprocess_audio(audio_path, config=None):
+    if config is None:
+        config = AudioConfig()
+
+    # Load audio
+    y, sr = librosa.load(audio_path, sr=config.sr)
+
+    # Trim or pad
+    if len(y) > config.samples:
+        y = y[:config.samples]
+    else:
+        padding = config.samples - len(y)
+        offset = padding // 2
+        y = np.pad(y, (offset, padding - offset), 'constant')
+
+    # Create mel spectrogram
+    spectrogram = librosa.feature.melspectrogram(
+        y=y,
+        sr=config.sr,
+        n_mels=config.n_mels,
+        hop_length=config.hop_length,
+        n_fft=config.n_fft,
+        fmin=config.fmin,
+        fmax=config.fmax
+    )
+    spectrogram = librosa.power_to_db(spectrogram)
+
+    # Return with correct shape for PyTorch (channels, height, width)
+    return spectrogram[np.newaxis, ...]