Update tasks/audio.py
Browse files- tasks/audio.py +30 -9
tasks/audio.py
CHANGED
|
@@ -239,18 +239,39 @@ for audio_data in test_dataset["audio"]:
|
|
| 239 |
|
| 240 |
print("Predictions:", predictions)
|
| 241 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
-
|
| 244 |
#--------------------------------------------------------------------------------------------
|
| 245 |
# YOUR MODEL INFERENCE STOPS HERE
|
| 246 |
-
#--------------------------------------------------------------------------------------------
|
| 247 |
-
|
| 248 |
# Stop tracking emissions
|
| 249 |
emissions_data = tracker.stop_task()
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
# Prepare results dictionary
|
| 255 |
results = {
|
| 256 |
"username": username,
|
|
@@ -268,5 +289,5 @@ print("Predictions:", predictions)
|
|
| 268 |
"test_seed": request.test_seed
|
| 269 |
}
|
| 270 |
}
|
| 271 |
-
|
| 272 |
-
return results
|
|
|
|
| 239 |
|
| 240 |
print("Predictions:", predictions)
|
| 241 |
|
| 242 |
+
def map_predictions_to_labels(predictions):
|
| 243 |
+
"""
|
| 244 |
+
Maps string predictions to numeric labels:
|
| 245 |
+
- "chainsaw" -> 0
|
| 246 |
+
- any other class -> 1
|
| 247 |
+
Args:
|
| 248 |
+
predictions (list of str): List of class name predictions.
|
| 249 |
+
Returns:
|
| 250 |
+
list of int: Mapped numeric labels.
|
| 251 |
+
"""
|
| 252 |
+
return [0 if pred == "chainsaw" else 1 for pred in predictions]
|
| 253 |
+
|
| 254 |
+
from sklearn.metrics import accuracy_score
|
| 255 |
+
|
| 256 |
+
# Map string predictions to numeric labels
|
| 257 |
+
numeric_predictions = map_predictions_to_labels(predictions)
|
| 258 |
+
|
| 259 |
+
# Extract true labels (already numeric)
|
| 260 |
+
true_labels = test_dataset["label"]
|
| 261 |
+
|
| 262 |
+
# Calculate accuracy
|
| 263 |
+
accuracy = accuracy_score(true_labels, numeric_predictions)
|
| 264 |
+
print("Accuracy:", accuracy)
|
| 265 |
|
|
|
|
| 266 |
#--------------------------------------------------------------------------------------------
|
| 267 |
# YOUR MODEL INFERENCE STOPS HERE
|
| 268 |
+
#--------------------------------------------------------------------------------------------
|
| 269 |
+
|
| 270 |
# Stop tracking emissions
|
| 271 |
emissions_data = tracker.stop_task()
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
|
|
|
|
| 275 |
# Prepare results dictionary
|
| 276 |
results = {
|
| 277 |
"username": username,
|
|
|
|
| 289 |
"test_seed": request.test_seed
|
| 290 |
}
|
| 291 |
}
|
| 292 |
+
|
| 293 |
+
return results
|