leynessa committed on
Commit
a91e356
·
verified ·
1 Parent(s): 62475ba

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +133 -53
streamlit_app.py CHANGED
@@ -47,7 +47,7 @@ def detect_model_architecture(model_state_dict):
47
  if 'classifier.weight' in model_state_dict:
48
  classifier_input_features = model_state_dict['classifier.weight'].shape[1]
49
 
50
- # EfficientNet feature mapping (corrected based on actual architectures)
51
  efficientnet_features = {
52
  1280: 'efficientnet_b0',
53
  1408: 'efficientnet_b1',
@@ -61,7 +61,7 @@ def detect_model_architecture(model_state_dict):
61
 
62
  model_name = efficientnet_features.get(classifier_input_features, 'efficientnet_b3')
63
  print(f"Detected model architecture: {model_name} (classifier features: {classifier_input_features})")
64
- return model_name
65
 
66
  # Check bn2 layer (final batch norm before classifier)
67
  if 'bn2.weight' in model_state_dict:
@@ -82,11 +82,11 @@ def detect_model_architecture(model_state_dict):
82
 
83
  model_name = bn2_mapping.get(bn2_features, 'efficientnet_b3')
84
  print(f"Detected model architecture from bn2: {model_name}")
85
- return model_name
86
 
87
  # Default to B3 based on your error logs
88
  print("Could not detect model architecture, defaulting to efficientnet_b3")
89
- return 'efficientnet_b3'
90
 
91
  @st.cache_resource
92
  def load_model():
@@ -105,6 +105,10 @@ def load_model():
105
 
106
  The model file appears to be a Git LFS pointer file (size: {file_size} bytes).
107
  This means the actual model wasn't downloaded properly.
 
 
 
 
108
  """)
109
  return None
110
 
@@ -128,7 +132,7 @@ def load_model():
128
  print(f" {key}: {model_state_dict[key].shape}")
129
 
130
  # Auto-detect the correct model architecture
131
- model_name = detect_model_architecture(model_state_dict)
132
 
133
  # Get number of classes
134
  num_classes = len(class_names)
@@ -137,32 +141,56 @@ def load_model():
137
 
138
  print(f"Loading {model_name} with {num_classes} classes")
139
 
140
- # Create model with correct architecture and matching parameters
141
- # Use the exact same parameters as in training
142
- model = timm.create_model(
143
- model_name,
144
- pretrained=False, # Don't load pretrained weights
145
- num_classes=num_classes,
146
- drop_rate=0.4, # Match training parameters
147
- drop_path_rate=0.3
148
- )
149
 
150
- # Load the trained weights
151
- try:
152
- model.load_state_dict(model_state_dict, strict=True)
153
- print("Model loaded successfully with strict=True")
154
- except RuntimeError as e:
155
- print(f"Strict loading failed: {e}")
156
- # Try with strict=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
158
 
159
  if missing_keys:
160
  print(f"Missing keys: {missing_keys}")
161
- st.warning(f"Some model weights were not loaded: {len(missing_keys)} missing keys")
162
 
163
  if unexpected_keys:
164
  print(f"Unexpected keys: {unexpected_keys}")
165
- st.warning(f"Some checkpoint keys were not used: {len(unexpected_keys)} unexpected keys")
166
 
167
  # Verify the model loaded correctly
168
  model.eval()
@@ -173,6 +201,12 @@ def load_model():
173
  try:
174
  dummy_output = model(dummy_input)
175
  print(f"Model test successful. Output shape: {dummy_output.shape}")
 
 
 
 
 
 
176
  except Exception as e:
177
  print(f"Model test failed: {e}")
178
  st.error(f"Model validation failed: {e}")
@@ -190,14 +224,19 @@ def load_model():
190
  checkpoint = torch.load(MODEL_PATH, map_location='cpu')
191
  if 'model_state_dict' in checkpoint:
192
  model_keys = list(checkpoint['model_state_dict'].keys())
193
- print(f"Available keys in checkpoint: {model_keys[:10]}...") # Show first 10 keys
194
 
195
  # Show the problematic layer shapes
196
  state_dict = checkpoint['model_state_dict']
197
  if 'classifier.weight' in state_dict and 'bn2.weight' in state_dict:
198
- print(f"Classifier input features: {state_dict['classifier.weight'].shape[1]}")
199
- print(f"bn2 features: {state_dict['bn2.weight'].shape[0]}")
 
 
200
 
 
 
 
201
  except Exception as debug_e:
202
  print(f"Debug info failed: {debug_e}")
203
 
@@ -207,8 +246,12 @@ def load_model():
207
  model = load_model()
208
 
209
  if model is None:
210
- st.warning("⚠️ **Model Loading Failed**")
211
- st.info("Please check the logs for detailed error information.")
 
 
 
 
212
  st.stop()
213
 
214
  # Transform for preprocessing (same as training)
@@ -241,11 +284,20 @@ def predict_butterfly(image):
241
  probabilities = torch.nn.functional.softmax(output[0], dim=0)
242
  confidence, pred = torch.max(probabilities, 0)
243
 
 
 
 
244
  if pred.item() < len(class_names):
245
  predicted_class = class_names[pred.item()]
246
  else:
247
  predicted_class = f"Class_{pred.item()}"
248
 
 
 
 
 
 
 
249
  return predicted_class, confidence.item()
250
 
251
  except Exception as e:
@@ -256,6 +308,10 @@ def predict_butterfly(image):
256
  st.title("🦋 Butterfly Identifier / Liblikamaja ID")
257
  st.write("Identify butterflies using your camera or by uploading an image!")
258
 
 
 
 
 
259
  # Create tabs for different input methods
260
  tab1, tab2 = st.tabs(["📷 Live Camera", "📁 Upload Image"])
261
 
@@ -275,20 +331,27 @@ with tab1:
275
  st.image(image, caption="Captured Image", use_column_width=True)
276
 
277
  with col2:
278
- predicted_class, confidence = predict_butterfly(image)
 
279
 
280
- if predicted_class and confidence and confidence >= 0.80:
281
- st.success(f"**Prediction: {predicted_class}**")
282
- st.info(f"Confidence: {confidence:.2%}")
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  else:
284
- st.warning("⚠️ **Image not clear - Unable to identify butterfly**")
285
- if confidence:
286
- st.info(f"Confidence too low: {confidence:.1%}")
287
- st.markdown("**Tips for better results:**")
288
- st.markdown("- Use better lighting")
289
- st.markdown("- Get closer to the butterfly")
290
- st.markdown("- Ensure the butterfly is clearly visible")
291
- st.markdown("- Avoid blurry or dark images")
292
 
293
  except Exception as e:
294
  st.error(f"Error processing image: {str(e)}")
@@ -314,20 +377,27 @@ with tab2:
314
  st.image(image, caption="Uploaded Image", use_column_width=True)
315
 
316
  with col2:
317
- predicted_class, confidence = predict_butterfly(image)
 
318
 
319
- if predicted_class and confidence and confidence >= 0.80:
320
- st.success(f"**Prediction: {predicted_class}**")
321
- st.info(f"Confidence: {confidence:.2%}")
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  else:
323
- st.warning("⚠️ **Image not clear - Unable to identify butterfly**")
324
- if confidence:
325
- st.info(f"Confidence too low: {confidence:.1%}")
326
- st.markdown("**Tips for better results:**")
327
- st.markdown("- Use better lighting")
328
- st.markdown("- Get closer to the butterfly")
329
- st.markdown("- Ensure the butterfly is clearly visible")
330
- st.markdown("- Avoid blurry or dark images")
331
 
332
  except Exception as e:
333
  st.error(f"Error processing image: {str(e)}")
@@ -337,4 +407,14 @@ st.markdown("---")
337
  st.markdown("### How to use:")
338
  st.markdown("1. **Camera Capture**: Take a photo using your device camera")
339
  st.markdown("2. **Upload Image**: Choose a butterfly photo from your device")
340
- st.markdown("3. **Best Results**: Use clear, well-lit photos with the butterfly clearly visible")
 
 
 
 
 
 
 
 
 
 
 
47
  if 'classifier.weight' in model_state_dict:
48
  classifier_input_features = model_state_dict['classifier.weight'].shape[1]
49
 
50
+ # EfficientNet feature mapping (updated based on actual timm implementations)
51
  efficientnet_features = {
52
  1280: 'efficientnet_b0',
53
  1408: 'efficientnet_b1',
 
61
 
62
  model_name = efficientnet_features.get(classifier_input_features, 'efficientnet_b3')
63
  print(f"Detected model architecture: {model_name} (classifier features: {classifier_input_features})")
64
+ return model_name, classifier_input_features
65
 
66
  # Check bn2 layer (final batch norm before classifier)
67
  if 'bn2.weight' in model_state_dict:
 
82
 
83
  model_name = bn2_mapping.get(bn2_features, 'efficientnet_b3')
84
  print(f"Detected model architecture from bn2: {model_name}")
85
+ return model_name, bn2_features
86
 
87
  # Default to B3 based on your error logs
88
  print("Could not detect model architecture, defaulting to efficientnet_b3")
89
+ return 'efficientnet_b3', 1792
90
 
91
  @st.cache_resource
92
  def load_model():
 
105
 
106
  The model file appears to be a Git LFS pointer file (size: {file_size} bytes).
107
  This means the actual model wasn't downloaded properly.
108
+
109
+ **To fix this:**
110
+ 1. Run: `git lfs pull` in your repository
111
+ 2. Or download the model file directly from your storage
112
  """)
113
  return None
114
 
 
132
  print(f" {key}: {model_state_dict[key].shape}")
133
 
134
  # Auto-detect the correct model architecture
135
+ model_name, expected_features = detect_model_architecture(model_state_dict)
136
 
137
  # Get number of classes
138
  num_classes = len(class_names)
 
141
 
142
  print(f"Loading {model_name} with {num_classes} classes")
143
 
144
+ # Create model with correct architecture
145
+ # Try different parameter combinations that might have been used during training
146
+ model_configs = [
147
+ # Most likely configuration based on your checkpoint
148
+ {'drop_rate': 0.4, 'drop_path_rate': 0.3},
149
+ {'drop_rate': 0.3, 'drop_path_rate': 0.2},
150
+ {'drop_rate': 0.2, 'drop_path_rate': 0.1},
151
+ {'drop_rate': 0.0, 'drop_path_rate': 0.0}, # Default
152
+ ]
153
 
154
+ model = None
155
+ for config in model_configs:
156
+ try:
157
+ print(f"Trying model config: {config}")
158
+ model = timm.create_model(
159
+ model_name,
160
+ pretrained=False,
161
+ num_classes=num_classes,
162
+ **config
163
+ )
164
+
165
+ # Try loading with strict=True first
166
+ model.load_state_dict(model_state_dict, strict=True)
167
+ print(f"Model loaded successfully with config: {config}")
168
+ break
169
+
170
+ except RuntimeError as e:
171
+ print(f"Config {config} failed: {e}")
172
+ continue
173
+
174
+ # If strict loading failed for all configs, try with strict=False
175
+ if model is None:
176
+ print("All strict loading attempts failed, trying with strict=False")
177
+ model = timm.create_model(
178
+ model_name,
179
+ pretrained=False,
180
+ num_classes=num_classes,
181
+ drop_rate=0.4,
182
+ drop_path_rate=0.3
183
+ )
184
+
185
  missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
186
 
187
  if missing_keys:
188
  print(f"Missing keys: {missing_keys}")
189
+ st.warning(f"⚠️ Some model weights were not loaded: {len(missing_keys)} missing keys")
190
 
191
  if unexpected_keys:
192
  print(f"Unexpected keys: {unexpected_keys}")
193
+ st.warning(f"⚠️ Some checkpoint keys were not used: {len(unexpected_keys)} unexpected keys")
194
 
195
  # Verify the model loaded correctly
196
  model.eval()
 
201
  try:
202
  dummy_output = model(dummy_input)
203
  print(f"Model test successful. Output shape: {dummy_output.shape}")
204
+
205
+ # Verify output shape matches expected classes
206
+ if dummy_output.shape[1] != num_classes:
207
+ st.error(f"Model output mismatch: expected {num_classes} classes, got {dummy_output.shape[1]}")
208
+ return None
209
+
210
  except Exception as e:
211
  print(f"Model test failed: {e}")
212
  st.error(f"Model validation failed: {e}")
 
224
  checkpoint = torch.load(MODEL_PATH, map_location='cpu')
225
  if 'model_state_dict' in checkpoint:
226
  model_keys = list(checkpoint['model_state_dict'].keys())
227
+ print(f"Available keys in checkpoint: {model_keys[:10]}...")
228
 
229
  # Show the problematic layer shapes
230
  state_dict = checkpoint['model_state_dict']
231
  if 'classifier.weight' in state_dict and 'bn2.weight' in state_dict:
232
+ classifier_features = state_dict['classifier.weight'].shape[1]
233
+ bn2_features = state_dict['bn2.weight'].shape[0]
234
+ print(f"Classifier input features: {classifier_features}")
235
+ print(f"bn2 features: {bn2_features}")
236
 
237
+ if classifier_features != bn2_features:
238
+ st.error(f"Architecture mismatch: classifier expects {classifier_features} features, but bn2 has {bn2_features}")
239
+
240
  except Exception as debug_e:
241
  print(f"Debug info failed: {debug_e}")
242
 
 
246
  model = load_model()
247
 
248
  if model is None:
249
+ st.error("⚠️ **Model Loading Failed**")
250
+ st.info("**Possible solutions:**")
251
+ st.markdown("1. **Git LFS issue**: Run `git lfs pull` to download the actual model file")
252
+ st.markdown("2. **Architecture mismatch**: The model was trained with different parameters")
253
+ st.markdown("3. **Corrupted file**: Re-download or re-train the model")
254
+ st.markdown("4. **Check the console/logs** for detailed error information")
255
  st.stop()
256
 
257
  # Transform for preprocessing (same as training)
 
284
  probabilities = torch.nn.functional.softmax(output[0], dim=0)
285
  confidence, pred = torch.max(probabilities, 0)
286
 
287
+ # Get top 3 predictions for better debugging
288
+ top_probs, top_indices = torch.topk(probabilities, min(3, len(probabilities)))
289
+
290
  if pred.item() < len(class_names):
291
  predicted_class = class_names[pred.item()]
292
  else:
293
  predicted_class = f"Class_{pred.item()}"
294
 
295
+ # Debug info
296
+ print(f"Top predictions:")
297
+ for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
298
+ class_name = class_names[idx.item()] if idx.item() < len(class_names) else f"Class_{idx.item()}"
299
+ print(f" {i+1}. {class_name}: {prob.item():.3f}")
300
+
301
  return predicted_class, confidence.item()
302
 
303
  except Exception as e:
 
308
  st.title("🦋 Butterfly Identifier / Liblikamaja ID")
309
  st.write("Identify butterflies using your camera or by uploading an image!")
310
 
311
+ # Show model info
312
+ if model is not None:
313
+ st.info(f"📊 Model loaded: {len(class_names)} butterfly species recognized")
314
+
315
  # Create tabs for different input methods
316
  tab1, tab2 = st.tabs(["📷 Live Camera", "📁 Upload Image"])
317
 
 
331
  st.image(image, caption="Captured Image", use_column_width=True)
332
 
333
  with col2:
334
+ with st.spinner("Analyzing image..."):
335
+ predicted_class, confidence = predict_butterfly(image)
336
 
337
+ if predicted_class and confidence:
338
+ if confidence >= 0.80:
339
+ st.success(f"**Prediction: {predicted_class}**")
340
+ st.info(f"Confidence: {confidence:.2%}")
341
+
342
+ # Show additional info if available
343
+ if predicted_class in butterfly_info:
344
+ st.write(f"**About:** {butterfly_info[predicted_class]['description']}")
345
+ else:
346
+ st.warning("⚠️ **Low confidence prediction**")
347
+ st.info(f"Best guess: {predicted_class} ({confidence:.1%})")
348
+ st.markdown("**Tips for better results:**")
349
+ st.markdown("- Use better lighting")
350
+ st.markdown("- Get closer to the butterfly")
351
+ st.markdown("- Ensure the butterfly is clearly visible")
352
+ st.markdown("- Avoid blurry or dark images")
353
  else:
354
+ st.error("Unable to analyze image. Please try again.")
 
 
 
 
 
 
 
355
 
356
  except Exception as e:
357
  st.error(f"Error processing image: {str(e)}")
 
377
  st.image(image, caption="Uploaded Image", use_column_width=True)
378
 
379
  with col2:
380
+ with st.spinner("Analyzing image..."):
381
+ predicted_class, confidence = predict_butterfly(image)
382
 
383
+ if predicted_class and confidence:
384
+ if confidence >= 0.80:
385
+ st.success(f"**Prediction: {predicted_class}**")
386
+ st.info(f"Confidence: {confidence:.2%}")
387
+
388
+ # Show additional info if available
389
+ if predicted_class in butterfly_info:
390
+ st.write(f"**About:** {butterfly_info[predicted_class]['description']}")
391
+ else:
392
+ st.warning("⚠️ **Low confidence prediction**")
393
+ st.info(f"Best guess: {predicted_class} ({confidence:.1%})")
394
+ st.markdown("**Tips for better results:**")
395
+ st.markdown("- Use better lighting")
396
+ st.markdown("- Get closer to the butterfly")
397
+ st.markdown("- Ensure the butterfly is clearly visible")
398
+ st.markdown("- Avoid blurry or dark images")
399
  else:
400
+ st.error("Unable to analyze image. Please try again.")
 
 
 
 
 
 
 
401
 
402
  except Exception as e:
403
  st.error(f"Error processing image: {str(e)}")
 
407
  st.markdown("### How to use:")
408
  st.markdown("1. **Camera Capture**: Take a photo using your device camera")
409
  st.markdown("2. **Upload Image**: Choose a butterfly photo from your device")
410
+ st.markdown("3. **Best Results**: Use clear, well-lit photos with the butterfly clearly visible")
411
+
412
+ # Debug info (only show if there are issues)
413
+ if st.checkbox("Show debug information"):
414
+ st.markdown("### Debug Information")
415
+ st.write(f"Number of classes: {len(class_names)}")
416
+ st.write(f"Model loaded: {model is not None}")
417
+ if model:
418
+ st.write("Model architecture successfully detected and loaded")
419
+ else:
420
+ st.write("❌ Model failed to load - check console for details")