Spaces:
Sleeping
Sleeping
Commit
·
e375d50
1
Parent(s):
8137f13
Update app.py
Browse files
app.py
CHANGED
@@ -30,10 +30,19 @@ def aggregate_predictions(predicted_probs, taxonomic_level, class_names):
|
|
30 |
species = row['species']
|
31 |
higher_level = row[taxonomic_level]
|
32 |
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
return aggregated_predictions, unique_labels
|
39 |
|
|
|
30 |
species = row['species']
|
31 |
higher_level = row[taxonomic_level]
|
32 |
|
33 |
+
if species in class_names: # Check if species exists in class names
|
34 |
+
species_index = class_names.index(species) # Index of the species in the prediction array
|
35 |
+
|
36 |
+
if higher_level in unique_labels: # Check if higher level exists
|
37 |
+
higher_level_index = unique_labels.index(higher_level)
|
38 |
+
|
39 |
+
# Only update if indices are valid
|
40 |
+
if species_index < predicted_probs.shape[1] and higher_level_index < aggregated_predictions.shape[1]:
|
41 |
+
if predicted_probs[:, species_index].max() >= 0.80: # Check confidence level
|
42 |
+
aggregated_predictions[:, higher_level_index] += predicted_probs[:, species_index]
|
43 |
+
else:
|
44 |
+
# Stop aggregation at the current level if confidence is below 0.80
|
45 |
+
break
|
46 |
|
47 |
return aggregated_predictions, unique_labels
|
48 |
|