Vector73 committed
Commit 01e938d · 1 Parent(s): 8c38d83

Add audio model.
.gitignore ADDED
@@ -0,0 +1 @@
+**/__pycache__
app.py CHANGED
@@ -2,25 +2,109 @@ import streamlit as st
 import cv2
 import numpy as np
 import os
+import tempfile
+import librosa
+import librosa.display
+import matplotlib.pyplot as plt
 from PIL import Image
 import torch
-from predict import load_onnx_model
+
+# Import deforestation modules
+from prediction_engine import load_onnx_model
 from utils.helpers import calculate_deforestation_metrics, create_overlay
 
+# Import audio classification modules
+from utils.audio_processing import preprocess_audio
+from utils.audio_model import load_audio_model, predict_audio, class_names
+
+# Ensure torch classes path is initialized to avoid warnings
 torch.classes.__path__ = []
 
 # Set page config
-st.set_page_config(page_title="Deforestation Detection", page_icon="🌳", layout="wide")
+st.set_page_config(
+    page_title="Nature Nexus - Forest Surveillance",
+    page_icon="🌳",
+    layout="wide",
+    initial_sidebar_state="expanded"
+)
+
+# Constants
+DEFOREST_MODEL_INPUT_SIZE = 256
+AUDIO_MODEL_PATH = "models/best_model.pth"
 
-# Set constants
-MODEL_INPUT_SIZE = 256  # The size our model expects
+# Initialize session state for navigation
+if 'current_service' not in st.session_state:
+    st.session_state.current_service = 'deforestation'
+if 'audio_input_method' not in st.session_state:
+    st.session_state.audio_input_method = 'upload'
+
+# Sidebar for navigation
+with st.sidebar:
+    st.title("Nature Nexus")
+    st.subheader("Forest Surveillance System")
+
+    selected_service = st.radio(
+        "Select Service:",
+        ["Deforestation Detection", "Forest Audio Surveillance"]
+    )
+    st.session_state.current_service = 'deforestation' if selected_service == "Deforestation Detection" else 'audio'
+
+    st.markdown("---")
+
+    # Service-specific sidebar content
+    if st.session_state.current_service == 'deforestation':
+        st.info(
+            """
+            **Deforestation Detection**
+
+            Upload satellite or aerial images to detect areas of deforestation.
+            """
+        )
+    else:
+        st.info(
+            """
+            **Forest Audio Surveillance**
+
+            Detect unusual human-related sounds in forested regions.
+            """
+        )
+
+        # Audio service specific controls
+        st.subheader("Audio Configuration")
+        audio_input_method = st.radio(
+            "Select Input Method:",
+            ("Upload Audio", "Record Audio"),
+            index=0 if st.session_state.audio_input_method == 'upload' else 1
+        )
+        st.session_state.audio_input_method = 'upload' if audio_input_method == "Upload Audio" else 'record'
+
+        # Audio class information
+        st.markdown("**Detection Classes:**")
+
+        # Group classes by category
+        human_sounds = ['footsteps', 'coughing', 'laughing', 'breathing',
+                        'drinking_sipping', 'snoring', 'sneezing']
+        tool_sounds = ['chainsaw', 'hand_saw']
+        vehicle_sounds = ['car_horn', 'engine', 'siren']
+        other_sounds = ['crackling_fire', 'fireworks']
+
+        st.markdown("👤 **Human Sounds:** " + ", ".join([s.capitalize() for s in human_sounds]))
+        st.markdown("🔨 **Tool Sounds:** " + ", ".join([s.capitalize() for s in tool_sounds]))
+        st.markdown("🚗 **Vehicle Sounds:** " + ", ".join([s.capitalize() for s in vehicle_sounds]))
+        st.markdown("💥 **Other Sounds:** " + ", ".join([s.capitalize() for s in other_sounds]))
 
-# Load ONNX model
+# Load deforestation model
 @st.cache_resource
-def load_cached_onnx_model():
+def load_cached_deforestation_model():
     model_path = "models/deforestation_model.onnx"
-    return load_onnx_model(model_path, input_size=MODEL_INPUT_SIZE)
+    return load_onnx_model(model_path, input_size=DEFOREST_MODEL_INPUT_SIZE)
+
+# Load audio model
+@st.cache_resource
+def load_cached_audio_model():
+    return load_audio_model(AUDIO_MODEL_PATH)
 
+# Process image for deforestation detection
 def process_image(model, image):
     """Process a single image and return results"""
     # Save original image dimensions for display
@@ -43,24 +127,121 @@ def process_image(model, image):
 
     return binary_mask, overlay, metrics
 
-def main():
+# Visualize audio for audio classification
+def visualize_audio(audio_path):
+    y, sr = librosa.load(audio_path, sr=16000)
+    duration = len(y) / sr
+
+    fig, ax = plt.subplots(2, 1, figsize=(10, 6))
+
+    # Waveform plot
+    librosa.display.waveshow(y, sr=sr, ax=ax[0])
+    ax[0].set_title('Audio Waveform')
+    ax[0].set_xlabel('Time (s)')
+    ax[0].set_ylabel('Amplitude')
+
+    # Spectrogram plot
+    S = librosa.feature.melspectrogram(y=y, sr=sr)
+    S_db = librosa.power_to_db(S, ref=np.max)
+    img = librosa.display.specshow(S_db, sr=sr, x_axis='time', y_axis='mel', ax=ax[1])
+    fig.colorbar(img, ax=ax[1], format='%+2.0f dB')
+    ax[1].set_title('Mel Spectrogram')
+
+    plt.tight_layout()
+    st.pyplot(fig)
+
+    return y, sr, duration
+
+# Process audio for classification
+def process_audio(audio_file):
+    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
+        tmp_file.write(audio_file.read() if hasattr(audio_file, 'read') else audio_file)
+        audio_path = tmp_file.name
+
+    try:
+        # Load audio model
+        audio_model = load_cached_audio_model()
+
+        # Visualize audio
+        with st.spinner('Analyzing audio...'):
+            y, sr, duration = visualize_audio(audio_path)
+            st.caption(f"Audio duration: {duration:.2f} seconds")
+
+        # Make prediction
+        with st.spinner('Making prediction...'):
+            class_name, confidence = predict_audio(audio_path, audio_model)
+
+        # Display results
+        st.subheader("Detection Results")
+
+        col1, col2 = st.columns(2)
+        with col1:
+            st.metric("Detected Sound", class_name.replace('_', ' ').title())
+        with col2:
+            st.metric("Confidence", f"{confidence*100:.2f}%")
+
+        # Show alerts based on class
+        human_sounds = ['footsteps', 'coughing', 'laughing', 'breathing',
+                        'drinking_sipping', 'snoring', 'sneezing']
+        tool_sounds = ['chainsaw', 'hand_saw']
+
+        if class_name in human_sounds:
+            st.warning("""
+            ⚠️ **Human Activity Detected!**
+            Potential human presence in the monitored area.
+            """)
+        elif class_name in tool_sounds:
+            st.error("""
+            🚨 **ALERT: Human Tool Detected!**
+            Potential illegal logging or activity detected. Consider immediate verification.
+            """)
+        elif class_name in ['car_horn', 'engine', 'siren']:
+            st.warning("""
+            ⚠️ **Vehicle Detected!**
+            Vehicle sounds detected in the monitored area.
+            """)
+        elif class_name == 'fireworks':
+            st.error("""
+            🚨 **ALERT: Fireworks Detected!**
+            Potential fire hazard and disturbance to wildlife. Immediate verification required.
+            """)
+        elif class_name == 'crackling_fire':
+            st.error("""
+            🚨 **ALERT: Fire Detected!**
+            Potential wildfire detected. Immediate verification required.
+            """)
+        else:
+            st.success("✅ Environmental sound detected - no immediate threat")
+
+    except Exception as e:
+        st.error(f"Error processing audio: {str(e)}")
+        st.exception(e)
+    finally:
+        # Clean up temp file
+        try:
+            os.unlink(audio_path)
+        except:
+            pass
+
+# Deforestation detection UI
+def show_deforestation_detection():
     # App title and description
     st.title("🌳 Deforestation Detection")
     st.markdown(
         """
-    This app detects areas of deforestation in satellite or aerial images of forests.
-    Upload an image to get started!
-    """
+    This service detects areas of deforestation in satellite or aerial images of forests.
+    Upload an image to get started!
+    """
     )
 
     # Model info
     st.info(
-        f"⚙️ Model optimized for {MODEL_INPUT_SIZE}x{MODEL_INPUT_SIZE} pixel images using ONNX runtime"
+        f"⚙️ Model optimized for {DEFOREST_MODEL_INPUT_SIZE}x{DEFOREST_MODEL_INPUT_SIZE} pixel images using ONNX runtime"
    )
 
     # Load model
     try:
-        model = load_cached_onnx_model()
+        model = load_cached_deforestation_model()
     except Exception as e:
         st.error(f"Error loading model: {e}")
         st.info(
@@ -139,5 +320,81 @@ def main():
     except Exception as e:
         st.error(f"Error during processing: {e}")
 
+# Audio classification UI
+def show_audio_classification():
+    # App title and description
+    st.title("🎧 Forest Audio Surveillance")
+    st.markdown("""
+    Detect unusual human-related sounds in forested regions to prevent illegal activities.
+    Supported sounds: {}
+    """.format(", ".join(class_names)))
+
+    if st.session_state.audio_input_method == 'upload':
+        st.header("Upload Audio File")
+
+        sample_col, upload_col = st.columns(2)
+        with sample_col:
+            st.info("Upload a WAV, MP3 or OGG file with forest sounds")
+            st.markdown("""
+            **Tips for best results:**
+            - Use audio with minimal background noise
+            - Ensure the sound of interest is clear
+            - 2-3 second clips work best
+            """)
+
+        with upload_col:
+            audio_file = st.file_uploader(
+                "Choose an audio file",
+                type=["wav", "mp3", "ogg"],
+                help="Supported formats: WAV, MP3, OGG"
+            )
+
+        if audio_file:
+            st.success("File uploaded successfully!")
+            with st.expander("Audio Preview", expanded=True):
+                st.audio(audio_file)
+            process_audio(audio_file)
+
+    else:  # Record mode
+        st.header("Record Live Audio")
+
+        st.info("""
+        Click the microphone button below to record a sound for analysis.
+        **Note:** Please ensure your browser has permission to access your microphone.
+        When prompted, click "Allow" to enable recording.
+        """)
+
+        recorded_audio = st.audio_input(
+            label="Record a sound",
+            key="audio_recorder",
+            help="Click to record forest sounds for analysis",
+            label_visibility="visible"
+        )
+
+        if recorded_audio:
+            st.success("Audio recorded successfully!")
+            with st.expander("Recorded Audio", expanded=True):
+                st.audio(recorded_audio)
+            process_audio(recorded_audio)
+        else:
+            st.write("Waiting for recording...")
+
+# Main function
+def main():
+    # Check which service is selected and render appropriate UI
+    if st.session_state.current_service == 'deforestation':
+        show_deforestation_detection()
+    else:
+        show_audio_classification()
+
+    # Footer
+    st.markdown("---")
+    st.markdown("""
+    <div style="text-align: center; padding: 10px;">
+        <p>Nature Nexus - Forest Surveillance System | 🌳 Protect Natural Ecosystems</p>
+        <p><small>Built with Streamlit and PyTorch</small></p>
+    </div>
+    """, unsafe_allow_html=True)
+
 if __name__ == "__main__":
     main()
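For reference, a minimal sketch of the temp-file pattern that process_audio relies on: both Streamlit uploads and st.audio_input recordings are treated either as file-like objects (read()) or raw bytes, written to a temporary WAV, and then handed to librosa by path. The helper name below is hypothetical and not part of the commit.

import tempfile

def to_temp_wav(audio_source):
    """Write file-like or raw-bytes audio to a temporary .wav and return its path."""
    data = audio_source.read() if hasattr(audio_source, "read") else audio_source
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(data)
    return tmp.name

# Usage (hypothetical): path = to_temp_wav(audio_file); analyze path; then os.unlink(path)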
models/best_model.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fff4ca890869016c359ce0991e22c0df72bdaee45b4512f5252967fe44361095
+size 5148310
prediction_engine.py ADDED
@@ -0,0 +1,157 @@
+import torch
+import numpy as np
+import cv2
+import onnxruntime as ort
+from utils.preprocess import preprocess_image
+
+
+class PredictionEngine:
+    def __init__(self, model_path=None, use_onnx=True, input_size=256):
+        """
+        Initialize the prediction engine
+
+        Args:
+            model_path: Path to the model file (PyTorch or ONNX)
+            use_onnx: Whether to use ONNX runtime for inference
+            input_size: Input size for the model (default is 256)
+        """
+        self.use_onnx = use_onnx
+        self.input_size = input_size
+
+        if model_path:
+            if use_onnx:
+                self.model = self._load_onnx_model(model_path)
+            else:
+                self.model = self._load_pytorch_model(model_path)
+        else:
+            self.model = None
+
+    def _load_onnx_model(self, model_path):
+        """
+        Load an ONNX model
+
+        Args:
+            model_path: Path to the ONNX model
+
+        Returns:
+            ONNX Runtime InferenceSession
+        """
+        # Try with CUDA first, fall back to CPU if needed
+        try:
+            session = ort.InferenceSession(
+                model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
+            )
+            print("ONNX model loaded with CUDA support")
+            return session
+        except Exception as e:
+            print(f"Could not load ONNX model with CUDA, falling back to CPU: {e}")
+            session = ort.InferenceSession(
+                model_path, providers=["CPUExecutionProvider"]
+            )
+            print("ONNX model loaded with CPU support")
+            return session
+
+    def _load_pytorch_model(self, model_path):
+        """
+        Load a PyTorch model
+
+        Args:
+            model_path: Path to the PyTorch model
+
+        Returns:
+            PyTorch model
+        """
+        from utils.model import load_model
+        return load_model(model_path)
+
+    def preprocess(self, image):
+        """
+        Preprocess an image for prediction
+
+        Args:
+            image: Input image (numpy array)
+
+        Returns:
+            Processed image suitable for the model
+        """
+        # Keep the original image for reference
+        self.original_shape = image.shape[:2]
+
+        # Preprocess image
+        if self.use_onnx:
+            # For ONNX, we need to ensure the input is exactly the expected size
+            tensor = preprocess_image(image, img_size=self.input_size)
+            return tensor.numpy()
+        else:
+            # For PyTorch
+            return preprocess_image(image, img_size=self.input_size)
+
+    def predict(self, image):
+        """
+        Make a prediction on an image
+
+        Args:
+            image: Input image (numpy array)
+
+        Returns:
+            Predicted mask
+        """
+        if self.model is None:
+            raise ValueError("Model not loaded. Initialize with a valid model path.")
+
+        # Preprocess the image
+        processed_input = self.preprocess(image)
+
+        # Run inference
+        if self.use_onnx:
+            # Get input and output names
+            input_name = self.model.get_inputs()[0].name
+            output_name = self.model.get_outputs()[0].name
+
+            # Run ONNX inference
+            outputs = self.model.run([output_name], {input_name: processed_input})
+
+            # Apply sigmoid to output
+            mask = 1 / (1 + np.exp(-outputs[0].squeeze()))
+        else:
+            # PyTorch inference
+            with torch.no_grad():
+                # Move to device
+                device = next(self.model.parameters()).device
+                processed_input = processed_input.to(device)
+
+                # Forward pass
+                output = self.model(processed_input)
+                output = torch.sigmoid(output)
+
+                # Convert to numpy
+                mask = output.cpu().numpy().squeeze()
+
+        return mask
+
+
+def load_pytorch_model(model_path):
+    """
+    Load the PyTorch model for prediction
+
+    Args:
+        model_path: Path to the PyTorch model
+
+    Returns:
+        PredictionEngine instance
+    """
+    return PredictionEngine(model_path, use_onnx=False)
+
+
+def load_onnx_model(model_path, input_size=256):
+    """
+    Load the ONNX model for prediction
+
+    Args:
+        model_path: Path to the ONNX model
+        input_size: Input size for the model
+
+    Returns:
+        PredictionEngine instance
+    """
+    return PredictionEngine(model_path, use_onnx=True, input_size=input_size)
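A hedged usage sketch of the prediction_engine API defined above; the sample image path and the 0.5 threshold are illustrative assumptions, not part of the commit.

import cv2
from prediction_engine import load_onnx_model

engine = load_onnx_model("models/deforestation_model.onnx", input_size=256)
image = cv2.imread("sample_tile.png")        # hypothetical input tile (numpy array)
mask = engine.predict(image)                 # per-pixel probabilities in [0, 1] after sigmoid
binary = (mask > 0.5).astype("uint8")        # 0.5 threshold is an assumption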
requirements.txt CHANGED
@@ -10,4 +10,7 @@ scikit-learn
 matplotlib
 onnxruntime
 onnxruntime-gpu
-onnx
+onnx
+librosa
+soundfile
+pydub
utils/audio_model.py ADDED
@@ -0,0 +1,76 @@
+import torch
+import numpy as np
+from utils.audio_processing import preprocess_audio
+
+class_names = [
+    'fireworks', 'chainsaw', 'footsteps', 'car_horn', 'crackling_fire',
+    'drinking_sipping', 'laughing', 'engine', 'breathing', 'hand_saw',
+    'coughing', 'snoring', 'sneezing', 'siren'
+]
+
+class AudioClassifier(torch.nn.Module):
+    def __init__(self, num_classes=14):
+        super().__init__()
+        self.features = torch.nn.Sequential(
+            torch.nn.Conv2d(1, 64, 3, padding=1),
+            torch.nn.BatchNorm2d(64),
+            torch.nn.ReLU(),
+            torch.nn.Conv2d(64, 64, 3, padding=1),
+            torch.nn.BatchNorm2d(64),
+            torch.nn.ReLU(),
+            torch.nn.MaxPool2d(2),
+            torch.nn.Dropout(0.2),
+
+            torch.nn.Conv2d(64, 128, 3, padding=1),
+            torch.nn.BatchNorm2d(128),
+            torch.nn.ReLU(),
+            torch.nn.Conv2d(128, 128, 3, padding=1),
+            torch.nn.BatchNorm2d(128),
+            torch.nn.ReLU(),
+            torch.nn.MaxPool2d(2),
+            torch.nn.Dropout(0.2),
+
+            torch.nn.Conv2d(128, 256, 3, padding=1),
+            torch.nn.BatchNorm2d(256),
+            torch.nn.ReLU(),
+            torch.nn.Conv2d(256, 256, 3, padding=1),
+            torch.nn.BatchNorm2d(256),
+            torch.nn.ReLU(),
+            torch.nn.MaxPool2d(2),
+            torch.nn.Dropout(0.2)
+        )
+        self.classifier = torch.nn.Sequential(
+            torch.nn.AdaptiveAvgPool2d(1),
+            torch.nn.Flatten(),
+            torch.nn.Linear(256, 256),
+            torch.nn.ReLU(),
+            torch.nn.Linear(256, 256),
+            torch.nn.ReLU(),
+            torch.nn.Linear(256, num_classes)
+        )
+
+    def forward(self, x):
+        x = self.features(x)
+        return self.classifier(x)
+
+def load_audio_model(model_path='models/audio_model.pth'):
+    model = AudioClassifier(len(class_names))
+    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+    model.eval()
+    return model
+
+def predict_audio(audio_path, model):
+    # Preprocess audio
+    spec = preprocess_audio(audio_path)
+
+    # Convert to tensor
+    input_tensor = torch.FloatTensor(spec).unsqueeze(0)  # Add batch dimension
+
+    # Predict
+    with torch.no_grad():
+        outputs = model(input_tensor)
+        probabilities = torch.nn.functional.softmax(outputs, dim=1)
+
+    # Get results
+    pred_prob, pred_index = torch.max(probabilities, 1)
+    return class_names[pred_index.item()], pred_prob.item()
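A short usage sketch of the audio API above, assuming the checkpoint added in this commit and a hypothetical local clip:

from utils.audio_model import load_audio_model, predict_audio

model = load_audio_model("models/best_model.pth")             # checkpoint shipped in this commit
label, confidence = predict_audio("sample_clip.wav", model)   # hypothetical ~3 s recording
print(f"{label}: {confidence:.2%}")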
utils/audio_processing.py ADDED
@@ -0,0 +1,42 @@
+import librosa
+import numpy as np
+
+class AudioConfig:
+    sr = 16000
+    duration = 3
+    hop_length = 340 * duration
+    fmin = 20
+    fmax = sr // 2
+    n_mels = 128
+    n_fft = 128 * 20
+    samples = sr * duration
+
+def preprocess_audio(audio_path, config=None):
+    if config is None:
+        config = AudioConfig()
+
+    # Load audio
+    y, sr = librosa.load(audio_path, sr=config.sr)
+
+    # Trim or pad
+    if len(y) > config.samples:
+        y = y[:config.samples]
+    else:
+        padding = config.samples - len(y)
+        offset = padding // 2
+        y = np.pad(y, (offset, padding - offset), 'constant')
+
+    # Create mel spectrogram
+    spectrogram = librosa.feature.melspectrogram(
+        y=y,
+        sr=config.sr,
+        n_mels=config.n_mels,
+        hop_length=config.hop_length,
+        n_fft=config.n_fft,
+        fmin=config.fmin,
+        fmax=config.fmax
+    )
+    spectrogram = librosa.power_to_db(spectrogram)
+
+    # Return with correct shape for PyTorch (channels, height, width)
+    return spectrogram[np.newaxis, ...]
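With the AudioConfig defaults, 3 s at 16 kHz is 48,000 samples; hop_length = 340 * 3 = 1020 then yields 1 + 48000 // 1020 = 48 frames, so the classifier receives a (1, 128, 48) spectrogram. A quick shape check, assuming a hypothetical local clip:

from utils.audio_processing import preprocess_audio

spec = preprocess_audio("clip.wav")   # hypothetical path
print(spec.shape)                     # expected: (1, 128, 48)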