Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -26,12 +26,12 @@ DEFAULT_PARAMS = {
|
|
| 26 |
"test_seed": 42, # must be non-negative
|
| 27 |
},
|
| 28 |
"image":{
|
| 29 |
-
"dataset_name": "
|
| 30 |
"test_size": 0.2, # must be between 0 and 1
|
| 31 |
"test_seed": 42, # must be non-negative
|
| 32 |
},
|
| 33 |
"audio":{
|
| 34 |
-
"dataset_name": "
|
| 35 |
"test_size": 0.2, # must be between 0 and 1
|
| 36 |
"test_seed": 42, # must be non-negative
|
| 37 |
}
|
|
@@ -61,19 +61,33 @@ def evaluate_model(task: str, space_url: str):
|
|
| 61 |
|
| 62 |
results = response.json()
|
| 63 |
|
| 64 |
-
# Check for required keys
|
| 65 |
-
|
| 66 |
-
"username", "space_url", "submission_timestamp", "model_description",
|
| 67 |
-
"
|
| 68 |
"api_route", "dataset_config"
|
| 69 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
missing_keys = required_keys - set(results.keys())
|
| 72 |
if missing_keys:
|
| 73 |
return None, None, None, gr.Warning(f"API response missing required keys: {', '.join(missing_keys)}")
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
return (
|
| 76 |
-
|
| 77 |
results["emissions_gco2eq"],
|
| 78 |
results["energy_consumed_wh"],
|
| 79 |
results
|
|
|
|
| 26 |
"test_seed": 42, # must be non-negative
|
| 27 |
},
|
| 28 |
"image":{
|
| 29 |
+
"dataset_name": "pyronear/pyro-sdis",
|
| 30 |
"test_size": 0.2, # must be between 0 and 1
|
| 31 |
"test_seed": 42, # must be non-negative
|
| 32 |
},
|
| 33 |
"audio":{
|
| 34 |
+
"dataset_name": "rfcx/frugalai",
|
| 35 |
"test_size": 0.2, # must be between 0 and 1
|
| 36 |
"test_seed": 42, # must be non-negative
|
| 37 |
}
|
|
|
|
| 61 |
|
| 62 |
results = response.json()
|
| 63 |
|
| 64 |
+
# Check for required keys based on task
|
| 65 |
+
base_required_keys = {
|
| 66 |
+
"username", "space_url", "submission_timestamp", "model_description",
|
| 67 |
+
"energy_consumed_wh", "emissions_gco2eq", "emissions_data",
|
| 68 |
"api_route", "dataset_config"
|
| 69 |
}
|
| 70 |
+
|
| 71 |
+
# Add task-specific accuracy keys
|
| 72 |
+
if task == "image":
|
| 73 |
+
accuracy_keys = {"classification_accuracy", "mean_iou"}
|
| 74 |
+
else: # text and audio
|
| 75 |
+
accuracy_keys = {"accuracy"}
|
| 76 |
+
|
| 77 |
+
required_keys = base_required_keys | accuracy_keys
|
| 78 |
|
| 79 |
missing_keys = required_keys - set(results.keys())
|
| 80 |
if missing_keys:
|
| 81 |
return None, None, None, gr.Warning(f"API response missing required keys: {', '.join(missing_keys)}")
|
| 82 |
|
| 83 |
+
# Return appropriate accuracy metric based on task
|
| 84 |
+
if task == "image":
|
| 85 |
+
accuracy = results["classification_accuracy"] # For display in UI
|
| 86 |
+
else:
|
| 87 |
+
accuracy = results["accuracy"]
|
| 88 |
+
|
| 89 |
return (
|
| 90 |
+
accuracy,
|
| 91 |
results["emissions_gco2eq"],
|
| 92 |
results["energy_consumed_wh"],
|
| 93 |
results
|