File size: 3,254 Bytes
4d6e8c2
 
 
 
 
42b7ac6
 
4d6e8c2
42b7ac6
 
 
 
 
 
 
 
 
 
 
4d6e8c2
 
42b7ac6
 
 
 
 
1c33274
70f5f26
1c33274
70f5f26
42b7ac6
 
 
 
4d6e8c2
70f5f26
 
42b7ac6
 
 
 
4d6e8c2
42b7ac6
 
4d6e8c2
42b7ac6
 
4d6e8c2
42b7ac6
 
4d6e8c2
42b7ac6
 
 
 
 
 
 
70f5f26
42b7ac6
 
4d6e8c2
42b7ac6
 
 
 
 
70f5f26
42b7ac6
 
70f5f26
4d6e8c2
42b7ac6
 
 
 
4d6e8c2
 
42b7ac6
4d6e8c2
 
 
 
 
 
 
70f5f26
4d6e8c2
 
 
 
1c33274
4d6e8c2
 
 
 
 
 
 
42b7ac6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from fastapi import APIRouter
from datetime import datetime
from datasets import load_dataset
from sklearn.metrics import accuracy_score

from .data.data_loaders import TextDataLoader
from .models.text_classifiers import BaselineModel
from .utils.evaluation import TextEvaluationRequest
from .utils.emissions import get_tracker, clean_emissions_data, get_space_info, EmissionsData

# --- Model instantiation -----------------------------------------------------
# Both classifiers are built eagerly at import time through the project's
# ModelFactory; each factory call receives a plain config dict.
from .models.text_classifiers import ModelFactory

_EMBEDDING_ML_CONFIG = {"model_type": "embeddingML"}
_DISTILBERT_CONFIG = {
    "model_type": "distilbert-pretrained",
    "model_name": "2025-01-27_17-00-47_DistilBERT_Model_fined-tuned_from_distilbert-base-uncased",
}

embedding_ml_model = ModelFactory.create_model(_EMBEDDING_ML_CONFIG)
distilbert_model = ModelFactory.create_model(_DISTILBERT_CONFIG)

# Model whose ``description`` is published on the route (see DESCRIPTION).
model_to_evaluate = distilbert_model

# --- Routing -------------------------------------------------------------------
router = APIRouter()
ROUTE = "/text"
DESCRIPTION = model_to_evaluate.description

@router.post(ROUTE, tags=["Text Task"],
             description=DESCRIPTION)
async def evaluate_text(request: TextEvaluationRequest,
                        track_emissions: bool = True,
                        # Default to the same model whose description is published
                        # in DESCRIPTION, so the reported "model_description" always
                        # matches the model actually used for inference.
                        # NOTE(review): being unannotated, FastAPI presumably exposes
                        # this as a query parameter a client could override with a
                        # string (which has no .predict) — consider hiding it.
                        model = model_to_evaluate,
                        light_dataset: bool = False) -> dict:
    """
    Evaluate text classification for climate disinformation detection.

    Loads the test split described by ``request``, runs ``model.predict`` on
    every quote (optionally measuring emissions), and reports accuracy
    together with energy/emissions figures and submission metadata.

    Parameters:
    -----------
    request: TextEvaluationRequest
        The request object containing the dataset configuration
        (``dataset_name``, ``test_size``, ``test_seed``).

    track_emissions: bool
        Whether to track emissions around the inference step. When False,
        a zeroed ``EmissionsData(0, 0)`` is reported instead.

    model: TextClassifier
        The model to use for inference; must expose ``predict(quote)``.
        Defaults to ``model_to_evaluate``.

    light_dataset: bool
        Whether to load a light version of the dataset (forwarded to
        ``TextDataLoader`` as ``light=``).

    Returns:
    --------
    dict
        A dictionary containing the evaluation results: submitter info,
        timestamp, model description, accuracy, energy (``*_wh``) and
        emissions (``*_gco2eq``) figures, cleaned emissions data, the API
        route, and the dataset configuration.
    """
    # Get space info
    username, space_url = get_space_info()

    # Load the dataset (done before tracking so loading cost is not measured)
    test_dataset = TextDataLoader(request, light=light_dataset).get_test_dataset()

    # Start tracking emissions for the inference step only
    if track_emissions:
        tracker = get_tracker()
        tracker.start()
        tracker.start_task("inference")

    # Model inference, one quote at a time.
    # NOTE(review): if predict raises here, the "inference" task is never
    # stopped — confirm whether get_tracker() returns a shared tracker.
    predictions = [model.predict(quote) for quote in test_dataset["quote"]]

    # Stop tracking emissions
    if track_emissions:
        emissions_data = tracker.stop_task()
    else:
        emissions_data = EmissionsData(0, 0)

    # Calculate accuracy
    true_labels = test_dataset["label"]
    accuracy = accuracy_score(true_labels, predictions)

    # Prepare results dictionary. The * 1000 factors presumably convert
    # kWh -> Wh and kg -> g, matching the key names (verify against tracker units).
    results = {
        "username": username,
        "space_url": space_url,
        "submission_timestamp": datetime.now().isoformat(),
        "model_description": DESCRIPTION,
        "accuracy": float(accuracy),
        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
        "emissions_data": clean_emissions_data(emissions_data),
        "api_route": ROUTE,
        "dataset_config": {
            "dataset_name": request.dataset_name,
            "test_size": request.test_size,
            "test_seed": request.test_seed
        }
    }

    return results