File size: 3,405 Bytes
b24406e
 
 
 
4d6e8c2
fe4a4cb
 
 
 
3b09640
fe4a4cb
4d6e8c2
b24406e
4d6e8c2
3b09640
 
 
4d6e8c2
 
b24406e
1c33274
70f5f26
fe4a4cb
3b09640
1c33274
70f5f26
4d6e8c2
 
fe4a4cb
70f5f26
b24406e
4d6e8c2
fe4a4cb
4d6e8c2
fe4a4cb
 
 
 
 
 
 
3b09640
 
 
fe4a4cb
 
 
 
 
 
 
 
 
b24406e
fe4a4cb
b24406e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe4a4cb
 
 
 
 
 
 
 
b24406e
fe4a4cb
 
 
 
4d6e8c2
 
fe4a4cb
70f5f26
fe4a4cb
 
 
 
 
 
4d6e8c2
 
70f5f26
4d6e8c2
fe4a4cb
 
b24406e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import librosa
import joblib
import numpy as np

from fastapi import APIRouter
from datetime import datetime
from datasets import load_dataset
from sklearn.metrics import accuracy_score
import random
import os

from .utils.evaluation import AudioEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info

from dotenv import load_dotenv
# Pull variables from a local .env file into the process environment
# (e.g. HF_TOKEN, which is read via os.getenv when loading the dataset).
load_dotenv()

# Route metadata: used both to register the endpoint and in the results payload.
DESCRIPTION = "Decision tree"
ROUTE = "/audio"

router = APIRouter()



def _extract_features(example, sampling_rate):
    """Return a fixed-length feature vector for one dataset example.

    The vector is the time-average of librosa's spectral contrast computed on
    ``example['audio']['array']``.

    Args:
        example: A dataset row; only ``example['audio']['array']`` is read.
        sampling_rate: Currently unused. ``spectral_contrast`` is called
            without ``sr``, so librosa's default sample rate is assumed.
            NOTE(review): the caller passes 12000 — confirm which rate
            ``model_audio.pkl`` was trained with before wiring this through,
            because changing ``sr`` changes the features the model sees.

    Returns:
        A 1-D array with one mean value per row of the spectral-contrast
        matrix (one per frequency band).
    """
    audio_array = example['audio']['array']
    contrast = librosa.feature.spectral_contrast(y=audio_array)
    # Average over frames so clips of any length map to the same vector size.
    # Assumes mono (1-D) audio — for multichannel input axis=1 would not be
    # the time axis. TODO confirm the dataset is mono.
    return np.mean(contrast, axis=1)


def _predict_new_audio(model, dataset, sampling_rate):
    """Extract features for every example in ``dataset`` and classify them.

    Args:
        model: A fitted estimator exposing ``predict`` (loaded with joblib).
        dataset: An iterable of examples in the shape ``_extract_features``
            expects.
        sampling_rate: Forwarded unchanged to ``_extract_features``.

    Returns:
        Whatever ``model.predict`` returns for the stacked feature matrix
        (one row per example, in dataset order).
    """
    features = np.vstack(
        [_extract_features(example, sampling_rate) for example in dataset]
    )
    return model.predict(features)


@router.post(ROUTE, tags=["Audio Task"],
             description=DESCRIPTION)
async def evaluate_audio(request: AudioEvaluationRequest):
    """
    Evaluate audio classification for rainforest sound detection.

    Current Model: Basic decision tree

    Loads ``request.dataset_name``, splits its ``train`` split using
    ``request.test_size`` / ``request.test_seed``, and runs the pickled model
    on the test part while the emissions tracker measures an "inference" task.
    Returns a dict with accuracy, energy/emissions figures and the dataset
    configuration used.

    NOTE(review): ``load_dataset``, ``joblib.load`` and feature extraction are
    synchronous and run on the event-loop thread inside this ``async def``;
    consider offloading them if the endpoint must serve concurrent requests.
    """
    # Get space info
    username, space_url = get_space_info()

    # Dataset labels: 0 = chainsaw, 1 = environment.
    # The dataset is gated, so authenticate with the HF_TOKEN environment variable.
    dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))

    # Split dataset; the seed makes the split reproducible for a given request.
    train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
    test_dataset = train_test["test"]

    # Start tracking emissions
    tracker.start()
    tracker.start_task("inference")

    #--------------------------------------------------------------------------------------------
    # MY MODEL
    #--------------------------------------------------------------------------------------------
    try:
        model_filename = "model_audio.pkl"
        # joblib.load unpickles the file: only ever point this at a trusted, local model.
        clf = joblib.load(model_filename)
        predictions = _predict_new_audio(clf, test_dataset, 12000)
    except Exception:
        # Close the "inference" task before propagating the error so the shared
        # tracker is not left with an open task that would distort the
        # measurement of later requests.
        tracker.stop_task()
        raise
    #--------------------------------------------------------------------------------------------
    # YOUR MODEL INFERENCE STOPS HERE
    #--------------------------------------------------------------------------------------------

    # Stop tracking emissions
    emissions_data = tracker.stop_task()

    # Calculate accuracy
    true_labels = test_dataset["label"]
    accuracy = accuracy_score(true_labels, predictions)

    # Prepare results dictionary (energy kWh -> Wh, emissions kg -> g)
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed
        }
    }

    return results