andrewzamp commited on
Commit
e375d50
·
1 Parent(s): 8137f13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -30,10 +30,19 @@ def aggregate_predictions(predicted_probs, taxonomic_level, class_names):
30
  species = row['species']
31
  higher_level = row[taxonomic_level]
32
 
33
- species_index = class_names.index(species) # Index of the species in the prediction array
34
- higher_level_index = unique_labels.index(higher_level)
35
-
36
- aggregated_predictions[:, higher_level_index] += predicted_probs[:, species_index]
 
 
 
 
 
 
 
 
 
37
 
38
  return aggregated_predictions, unique_labels
39
 
 
30
  species = row['species']
31
  higher_level = row[taxonomic_level]
32
 
33
+ if species in class_names: # Check if species exists in class names
34
+ species_index = class_names.index(species) # Index of the species in the prediction array
35
+
36
+ if higher_level in unique_labels: # Check if higher level exists
37
+ higher_level_index = unique_labels.index(higher_level)
38
+
39
+ # Only update if indices are valid
40
+ if species_index < predicted_probs.shape[1] and higher_level_index < aggregated_predictions.shape[1]:
41
+ if predicted_probs[:, species_index].max() >= 0.80: # Check confidence level
42
+ aggregated_predictions[:, higher_level_index] += predicted_probs[:, species_index]
43
+ else:
44
+ # Stop aggregation at the current level if confidence is below 0.80
45
+ break
46
 
47
  return aggregated_predictions, unique_labels
48