andrewzamp commited on
Commit
4d59f1b
·
1 Parent(s): 8500d79

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -2
app.py CHANGED
@@ -77,8 +77,12 @@ def make_prediction(image, taxonomic_decision, taxonomic_level):
77
  predicted_class_index = np.argmax(aggregated_predictions)
78
  predicted_class_name = aggregated_class_labels[predicted_class_index]
79
 
80
- # Construct the output message without considering confidence
81
- output_text = f"<h1 style='font-weight: bold;'>{predicted_class_name}</h1>"
 
 
 
 
82
 
83
  # Add the top 5 predictions
84
  output_text += "<h4 style='font-weight: bold; font-size: 1.2em;'>Top-5 predictions:</h4>"
@@ -148,6 +152,60 @@ def make_prediction(image, taxonomic_decision, taxonomic_level):
148
 
149
  return output_text
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  # Define the Gradio interface
152
  interface = gr.Interface(
153
  fn=make_prediction, # Function to be called for predictions
 
77
  predicted_class_index = np.argmax(aggregated_predictions)
78
  predicted_class_name = aggregated_class_labels[predicted_class_index]
79
 
80
+ # Check if common name should be displayed (only at species level)
81
+ if taxonomic_levels[current_level_index] == "species":
82
+ predicted_common_name = taxo_df[taxo_df[taxonomic_levels[current_level_index]] == predicted_class_name]['common_name'].values[0]
83
+ output_text = f"<h1 style='font-weight: bold;'><span style='font-style: italic;'>{predicted_class_name}</span> ({predicted_common_name})</h1>"
84
+ else:
85
+ output_text = f"<h1 style='font-weight: bold;'>{predicted_class_name}</h1>"
86
 
87
  # Add the top 5 predictions
88
  output_text += "<h4 style='font-weight: bold; font-size: 1.2em;'>Top-5 predictions:</h4>"
 
152
 
153
  return output_text
154
 
155
+ # Confidence checking for the automatic model decision
156
+ # Loop through taxonomic levels if the user lets the model decide
157
+ while current_level_index < len(taxonomic_levels):
158
+ # Aggregate predictions for the next level
159
+ aggregated_predictions, aggregated_class_labels = aggregate_predictions(prediction, taxonomic_levels[current_level_index], class_names)
160
+
161
+ # Check if the confidence of the top prediction meets the threshold
162
+ top_prediction_index = np.argmax(aggregated_predictions)
163
+ top_prediction_confidence = aggregated_predictions[0][top_prediction_index]
164
+
165
+ if top_prediction_confidence >= 0.80:
166
+ break # Confidence threshold met, exit loop
167
+
168
+ current_level_index += 1 # Move to the next taxonomic level
169
+
170
+ # Check if a valid prediction was made
171
+ if current_level_index == len(taxonomic_levels):
172
+ return "<h1 style='font-weight: bold;'>Unknown animal</h1>" # No valid predictions met the confidence criteria
173
+
174
+ # Get the predicted class name for the top prediction
175
+ predicted_class_index = np.argmax(aggregated_predictions)
176
+ predicted_class_name = aggregated_class_labels[predicted_class_index]
177
+
178
+ # Check if common name should be displayed (only at species level)
179
+ if taxonomic_levels[current_level_index] == "species":
180
+ predicted_common_name = taxo_df[taxo_df[taxonomic_levels[current_level_index]] == predicted_class_name]['common_name'].values[0]
181
+ output_text = f"<h1 style='font-weight: bold;'><span style='font-style: italic;'>{predicted_class_name}</span> ({predicted_common_name})</h1>"
182
+ else:
183
+ output_text = f"<h1 style='font-weight: bold;'>{predicted_class_name}</h1>"
184
+
185
+ # Add the top 5 predictions
186
+ output_text += "<h4 style='font-weight: bold; font-size: 1.2em;'>Top-5 predictions:</h4>"
187
+
188
+ top_indices = np.argsort(aggregated_predictions[0])[-5:][::-1] # Get top 5 predictions
189
+
190
+ for i in top_indices:
191
+ class_name = aggregated_class_labels[i]
192
+
193
+ if taxonomic_levels[current_level_index] == "species":
194
+ # Display common names only at species level and make it italic
195
+ common_name = taxo_df[taxo_df[taxonomic_levels[current_level_index]] == class_name]['common_name'].values[0]
196
+ confidence_percentage = aggregated_predictions[0][i] * 100
197
+ output_text += f"<div style='display: flex; justify-content: space-between;'>" \
198
+ f"<span style='font-style: italic;'>{class_name}</span>&nbsp;(<span>{common_name}</span>)" \
199
+ f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
200
+ else:
201
+ # No common names at higher taxonomic levels
202
+ confidence_percentage = aggregated_predictions[0][i] * 100
203
+ output_text += f"<div style='display: flex; justify-content: space-between;'>" \
204
+ f"<span>{class_name}</span>" \
205
+ f"<span style='margin-left: auto;'>{confidence_percentage:.2f}%</span></div>"
206
+
207
+ return output_text
208
+
209
  # Define the Gradio interface
210
  interface = gr.Interface(
211
  fn=make_prediction, # Function to be called for predictions