Spaces:
Sleeping
Sleeping
Update streamlit_app.py
Browse files- streamlit_app.py +133 -53
streamlit_app.py
CHANGED
@@ -47,7 +47,7 @@ def detect_model_architecture(model_state_dict):
|
|
47 |
if 'classifier.weight' in model_state_dict:
|
48 |
classifier_input_features = model_state_dict['classifier.weight'].shape[1]
|
49 |
|
50 |
-
# EfficientNet feature mapping (
|
51 |
efficientnet_features = {
|
52 |
1280: 'efficientnet_b0',
|
53 |
1408: 'efficientnet_b1',
|
@@ -61,7 +61,7 @@ def detect_model_architecture(model_state_dict):
|
|
61 |
|
62 |
model_name = efficientnet_features.get(classifier_input_features, 'efficientnet_b3')
|
63 |
print(f"Detected model architecture: {model_name} (classifier features: {classifier_input_features})")
|
64 |
-
return model_name
|
65 |
|
66 |
# Check bn2 layer (final batch norm before classifier)
|
67 |
if 'bn2.weight' in model_state_dict:
|
@@ -82,11 +82,11 @@ def detect_model_architecture(model_state_dict):
|
|
82 |
|
83 |
model_name = bn2_mapping.get(bn2_features, 'efficientnet_b3')
|
84 |
print(f"Detected model architecture from bn2: {model_name}")
|
85 |
-
return model_name
|
86 |
|
87 |
# Default to B3 based on your error logs
|
88 |
print("Could not detect model architecture, defaulting to efficientnet_b3")
|
89 |
-
return 'efficientnet_b3'
|
90 |
|
91 |
@st.cache_resource
|
92 |
def load_model():
|
@@ -105,6 +105,10 @@ def load_model():
|
|
105 |
|
106 |
The model file appears to be a Git LFS pointer file (size: {file_size} bytes).
|
107 |
This means the actual model wasn't downloaded properly.
|
|
|
|
|
|
|
|
|
108 |
""")
|
109 |
return None
|
110 |
|
@@ -128,7 +132,7 @@ def load_model():
|
|
128 |
print(f" {key}: {model_state_dict[key].shape}")
|
129 |
|
130 |
# Auto-detect the correct model architecture
|
131 |
-
model_name = detect_model_architecture(model_state_dict)
|
132 |
|
133 |
# Get number of classes
|
134 |
num_classes = len(class_names)
|
@@ -137,32 +141,56 @@ def load_model():
|
|
137 |
|
138 |
print(f"Loading {model_name} with {num_classes} classes")
|
139 |
|
140 |
-
# Create model with correct architecture
|
141 |
-
#
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
drop_rate
|
147 |
-
drop_path_rate
|
148 |
-
|
149 |
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
|
158 |
|
159 |
if missing_keys:
|
160 |
print(f"Missing keys: {missing_keys}")
|
161 |
-
st.warning(f"Some model weights were not loaded: {len(missing_keys)} missing keys")
|
162 |
|
163 |
if unexpected_keys:
|
164 |
print(f"Unexpected keys: {unexpected_keys}")
|
165 |
-
st.warning(f"Some checkpoint keys were not used: {len(unexpected_keys)} unexpected keys")
|
166 |
|
167 |
# Verify the model loaded correctly
|
168 |
model.eval()
|
@@ -173,6 +201,12 @@ def load_model():
|
|
173 |
try:
|
174 |
dummy_output = model(dummy_input)
|
175 |
print(f"Model test successful. Output shape: {dummy_output.shape}")
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
except Exception as e:
|
177 |
print(f"Model test failed: {e}")
|
178 |
st.error(f"Model validation failed: {e}")
|
@@ -190,14 +224,19 @@ def load_model():
|
|
190 |
checkpoint = torch.load(MODEL_PATH, map_location='cpu')
|
191 |
if 'model_state_dict' in checkpoint:
|
192 |
model_keys = list(checkpoint['model_state_dict'].keys())
|
193 |
-
print(f"Available keys in checkpoint: {model_keys[:10]}...")
|
194 |
|
195 |
# Show the problematic layer shapes
|
196 |
state_dict = checkpoint['model_state_dict']
|
197 |
if 'classifier.weight' in state_dict and 'bn2.weight' in state_dict:
|
198 |
-
|
199 |
-
|
|
|
|
|
200 |
|
|
|
|
|
|
|
201 |
except Exception as debug_e:
|
202 |
print(f"Debug info failed: {debug_e}")
|
203 |
|
@@ -207,8 +246,12 @@ def load_model():
|
|
207 |
model = load_model()
|
208 |
|
209 |
if model is None:
|
210 |
-
st.
|
211 |
-
st.info("
|
|
|
|
|
|
|
|
|
212 |
st.stop()
|
213 |
|
214 |
# Transform for preprocessing (same as training)
|
@@ -241,11 +284,20 @@ def predict_butterfly(image):
|
|
241 |
probabilities = torch.nn.functional.softmax(output[0], dim=0)
|
242 |
confidence, pred = torch.max(probabilities, 0)
|
243 |
|
|
|
|
|
|
|
244 |
if pred.item() < len(class_names):
|
245 |
predicted_class = class_names[pred.item()]
|
246 |
else:
|
247 |
predicted_class = f"Class_{pred.item()}"
|
248 |
|
|
|
|
|
|
|
|
|
|
|
|
|
249 |
return predicted_class, confidence.item()
|
250 |
|
251 |
except Exception as e:
|
@@ -256,6 +308,10 @@ def predict_butterfly(image):
|
|
256 |
st.title("🦋 Butterfly Identifier / Liblikamaja ID")
|
257 |
st.write("Identify butterflies using your camera or by uploading an image!")
|
258 |
|
|
|
|
|
|
|
|
|
259 |
# Create tabs for different input methods
|
260 |
tab1, tab2 = st.tabs(["📷 Live Camera", "📁 Upload Image"])
|
261 |
|
@@ -275,20 +331,27 @@ with tab1:
|
|
275 |
st.image(image, caption="Captured Image", use_column_width=True)
|
276 |
|
277 |
with col2:
|
278 |
-
|
|
|
279 |
|
280 |
-
if predicted_class and confidence
|
281 |
-
|
282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
283 |
else:
|
284 |
-
st.
|
285 |
-
if confidence:
|
286 |
-
st.info(f"Confidence too low: {confidence:.1%}")
|
287 |
-
st.markdown("**Tips for better results:**")
|
288 |
-
st.markdown("- Use better lighting")
|
289 |
-
st.markdown("- Get closer to the butterfly")
|
290 |
-
st.markdown("- Ensure the butterfly is clearly visible")
|
291 |
-
st.markdown("- Avoid blurry or dark images")
|
292 |
|
293 |
except Exception as e:
|
294 |
st.error(f"Error processing image: {str(e)}")
|
@@ -314,20 +377,27 @@ with tab2:
|
|
314 |
st.image(image, caption="Uploaded Image", use_column_width=True)
|
315 |
|
316 |
with col2:
|
317 |
-
|
|
|
318 |
|
319 |
-
if predicted_class and confidence
|
320 |
-
|
321 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
322 |
else:
|
323 |
-
st.
|
324 |
-
if confidence:
|
325 |
-
st.info(f"Confidence too low: {confidence:.1%}")
|
326 |
-
st.markdown("**Tips for better results:**")
|
327 |
-
st.markdown("- Use better lighting")
|
328 |
-
st.markdown("- Get closer to the butterfly")
|
329 |
-
st.markdown("- Ensure the butterfly is clearly visible")
|
330 |
-
st.markdown("- Avoid blurry or dark images")
|
331 |
|
332 |
except Exception as e:
|
333 |
st.error(f"Error processing image: {str(e)}")
|
@@ -337,4 +407,14 @@ st.markdown("---")
|
|
337 |
st.markdown("### How to use:")
|
338 |
st.markdown("1. **Camera Capture**: Take a photo using your device camera")
|
339 |
st.markdown("2. **Upload Image**: Choose a butterfly photo from your device")
|
340 |
-
st.markdown("3. **Best Results**: Use clear, well-lit photos with the butterfly clearly visible")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
if 'classifier.weight' in model_state_dict:
|
48 |
classifier_input_features = model_state_dict['classifier.weight'].shape[1]
|
49 |
|
50 |
+
# EfficientNet feature mapping (updated based on actual timm implementations)
|
51 |
efficientnet_features = {
|
52 |
1280: 'efficientnet_b0',
|
53 |
1408: 'efficientnet_b1',
|
|
|
61 |
|
62 |
model_name = efficientnet_features.get(classifier_input_features, 'efficientnet_b3')
|
63 |
print(f"Detected model architecture: {model_name} (classifier features: {classifier_input_features})")
|
64 |
+
return model_name, classifier_input_features
|
65 |
|
66 |
# Check bn2 layer (final batch norm before classifier)
|
67 |
if 'bn2.weight' in model_state_dict:
|
|
|
82 |
|
83 |
model_name = bn2_mapping.get(bn2_features, 'efficientnet_b3')
|
84 |
print(f"Detected model architecture from bn2: {model_name}")
|
85 |
+
return model_name, bn2_features
|
86 |
|
87 |
# Default to B3 based on your error logs
|
88 |
print("Could not detect model architecture, defaulting to efficientnet_b3")
|
89 |
+
return 'efficientnet_b3', 1792
|
90 |
|
91 |
@st.cache_resource
|
92 |
def load_model():
|
|
|
105 |
|
106 |
The model file appears to be a Git LFS pointer file (size: {file_size} bytes).
|
107 |
This means the actual model wasn't downloaded properly.
|
108 |
+
|
109 |
+
**To fix this:**
|
110 |
+
1. Run: `git lfs pull` in your repository
|
111 |
+
2. Or download the model file directly from your storage
|
112 |
""")
|
113 |
return None
|
114 |
|
|
|
132 |
print(f" {key}: {model_state_dict[key].shape}")
|
133 |
|
134 |
# Auto-detect the correct model architecture
|
135 |
+
model_name, expected_features = detect_model_architecture(model_state_dict)
|
136 |
|
137 |
# Get number of classes
|
138 |
num_classes = len(class_names)
|
|
|
141 |
|
142 |
print(f"Loading {model_name} with {num_classes} classes")
|
143 |
|
144 |
+
# Create model with correct architecture
|
145 |
+
# Try different parameter combinations that might have been used during training
|
146 |
+
model_configs = [
|
147 |
+
# Most likely configuration based on your checkpoint
|
148 |
+
{'drop_rate': 0.4, 'drop_path_rate': 0.3},
|
149 |
+
{'drop_rate': 0.3, 'drop_path_rate': 0.2},
|
150 |
+
{'drop_rate': 0.2, 'drop_path_rate': 0.1},
|
151 |
+
{'drop_rate': 0.0, 'drop_path_rate': 0.0}, # Default
|
152 |
+
]
|
153 |
|
154 |
+
model = None
|
155 |
+
for config in model_configs:
|
156 |
+
try:
|
157 |
+
print(f"Trying model config: {config}")
|
158 |
+
model = timm.create_model(
|
159 |
+
model_name,
|
160 |
+
pretrained=False,
|
161 |
+
num_classes=num_classes,
|
162 |
+
**config
|
163 |
+
)
|
164 |
+
|
165 |
+
# Try loading with strict=True first
|
166 |
+
model.load_state_dict(model_state_dict, strict=True)
|
167 |
+
print(f"Model loaded successfully with config: {config}")
|
168 |
+
break
|
169 |
+
|
170 |
+
except RuntimeError as e:
|
171 |
+
print(f"Config {config} failed: {e}")
|
172 |
+
continue
|
173 |
+
|
174 |
+
# If strict loading failed for all configs, try with strict=False
|
175 |
+
if model is None:
|
176 |
+
print("All strict loading attempts failed, trying with strict=False")
|
177 |
+
model = timm.create_model(
|
178 |
+
model_name,
|
179 |
+
pretrained=False,
|
180 |
+
num_classes=num_classes,
|
181 |
+
drop_rate=0.4,
|
182 |
+
drop_path_rate=0.3
|
183 |
+
)
|
184 |
+
|
185 |
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
|
186 |
|
187 |
if missing_keys:
|
188 |
print(f"Missing keys: {missing_keys}")
|
189 |
+
st.warning(f"⚠️ Some model weights were not loaded: {len(missing_keys)} missing keys")
|
190 |
|
191 |
if unexpected_keys:
|
192 |
print(f"Unexpected keys: {unexpected_keys}")
|
193 |
+
st.warning(f"⚠️ Some checkpoint keys were not used: {len(unexpected_keys)} unexpected keys")
|
194 |
|
195 |
# Verify the model loaded correctly
|
196 |
model.eval()
|
|
|
201 |
try:
|
202 |
dummy_output = model(dummy_input)
|
203 |
print(f"Model test successful. Output shape: {dummy_output.shape}")
|
204 |
+
|
205 |
+
# Verify output shape matches expected classes
|
206 |
+
if dummy_output.shape[1] != num_classes:
|
207 |
+
st.error(f"Model output mismatch: expected {num_classes} classes, got {dummy_output.shape[1]}")
|
208 |
+
return None
|
209 |
+
|
210 |
except Exception as e:
|
211 |
print(f"Model test failed: {e}")
|
212 |
st.error(f"Model validation failed: {e}")
|
|
|
224 |
checkpoint = torch.load(MODEL_PATH, map_location='cpu')
|
225 |
if 'model_state_dict' in checkpoint:
|
226 |
model_keys = list(checkpoint['model_state_dict'].keys())
|
227 |
+
print(f"Available keys in checkpoint: {model_keys[:10]}...")
|
228 |
|
229 |
# Show the problematic layer shapes
|
230 |
state_dict = checkpoint['model_state_dict']
|
231 |
if 'classifier.weight' in state_dict and 'bn2.weight' in state_dict:
|
232 |
+
classifier_features = state_dict['classifier.weight'].shape[1]
|
233 |
+
bn2_features = state_dict['bn2.weight'].shape[0]
|
234 |
+
print(f"Classifier input features: {classifier_features}")
|
235 |
+
print(f"bn2 features: {bn2_features}")
|
236 |
|
237 |
+
if classifier_features != bn2_features:
|
238 |
+
st.error(f"Architecture mismatch: classifier expects {classifier_features} features, but bn2 has {bn2_features}")
|
239 |
+
|
240 |
except Exception as debug_e:
|
241 |
print(f"Debug info failed: {debug_e}")
|
242 |
|
|
|
246 |
model = load_model()
|
247 |
|
248 |
if model is None:
|
249 |
+
st.error("⚠️ **Model Loading Failed**")
|
250 |
+
st.info("**Possible solutions:**")
|
251 |
+
st.markdown("1. **Git LFS issue**: Run `git lfs pull` to download the actual model file")
|
252 |
+
st.markdown("2. **Architecture mismatch**: The model was trained with different parameters")
|
253 |
+
st.markdown("3. **Corrupted file**: Re-download or re-train the model")
|
254 |
+
st.markdown("4. **Check the console/logs** for detailed error information")
|
255 |
st.stop()
|
256 |
|
257 |
# Transform for preprocessing (same as training)
|
|
|
284 |
probabilities = torch.nn.functional.softmax(output[0], dim=0)
|
285 |
confidence, pred = torch.max(probabilities, 0)
|
286 |
|
287 |
+
# Get top 3 predictions for better debugging
|
288 |
+
top_probs, top_indices = torch.topk(probabilities, min(3, len(probabilities)))
|
289 |
+
|
290 |
if pred.item() < len(class_names):
|
291 |
predicted_class = class_names[pred.item()]
|
292 |
else:
|
293 |
predicted_class = f"Class_{pred.item()}"
|
294 |
|
295 |
+
# Debug info
|
296 |
+
print(f"Top predictions:")
|
297 |
+
for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
|
298 |
+
class_name = class_names[idx.item()] if idx.item() < len(class_names) else f"Class_{idx.item()}"
|
299 |
+
print(f" {i+1}. {class_name}: {prob.item():.3f}")
|
300 |
+
|
301 |
return predicted_class, confidence.item()
|
302 |
|
303 |
except Exception as e:
|
|
|
308 |
st.title("🦋 Butterfly Identifier / Liblikamaja ID")
|
309 |
st.write("Identify butterflies using your camera or by uploading an image!")
|
310 |
|
311 |
+
# Show model info
|
312 |
+
if model is not None:
|
313 |
+
st.info(f"📊 Model loaded: {len(class_names)} butterfly species recognized")
|
314 |
+
|
315 |
# Create tabs for different input methods
|
316 |
tab1, tab2 = st.tabs(["📷 Live Camera", "📁 Upload Image"])
|
317 |
|
|
|
331 |
st.image(image, caption="Captured Image", use_column_width=True)
|
332 |
|
333 |
with col2:
|
334 |
+
with st.spinner("Analyzing image..."):
|
335 |
+
predicted_class, confidence = predict_butterfly(image)
|
336 |
|
337 |
+
if predicted_class and confidence:
|
338 |
+
if confidence >= 0.80:
|
339 |
+
st.success(f"**Prediction: {predicted_class}**")
|
340 |
+
st.info(f"Confidence: {confidence:.2%}")
|
341 |
+
|
342 |
+
# Show additional info if available
|
343 |
+
if predicted_class in butterfly_info:
|
344 |
+
st.write(f"**About:** {butterfly_info[predicted_class]['description']}")
|
345 |
+
else:
|
346 |
+
st.warning("⚠️ **Low confidence prediction**")
|
347 |
+
st.info(f"Best guess: {predicted_class} ({confidence:.1%})")
|
348 |
+
st.markdown("**Tips for better results:**")
|
349 |
+
st.markdown("- Use better lighting")
|
350 |
+
st.markdown("- Get closer to the butterfly")
|
351 |
+
st.markdown("- Ensure the butterfly is clearly visible")
|
352 |
+
st.markdown("- Avoid blurry or dark images")
|
353 |
else:
|
354 |
+
st.error("Unable to analyze image. Please try again.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
355 |
|
356 |
except Exception as e:
|
357 |
st.error(f"Error processing image: {str(e)}")
|
|
|
377 |
st.image(image, caption="Uploaded Image", use_column_width=True)
|
378 |
|
379 |
with col2:
|
380 |
+
with st.spinner("Analyzing image..."):
|
381 |
+
predicted_class, confidence = predict_butterfly(image)
|
382 |
|
383 |
+
if predicted_class and confidence:
|
384 |
+
if confidence >= 0.80:
|
385 |
+
st.success(f"**Prediction: {predicted_class}**")
|
386 |
+
st.info(f"Confidence: {confidence:.2%}")
|
387 |
+
|
388 |
+
# Show additional info if available
|
389 |
+
if predicted_class in butterfly_info:
|
390 |
+
st.write(f"**About:** {butterfly_info[predicted_class]['description']}")
|
391 |
+
else:
|
392 |
+
st.warning("⚠️ **Low confidence prediction**")
|
393 |
+
st.info(f"Best guess: {predicted_class} ({confidence:.1%})")
|
394 |
+
st.markdown("**Tips for better results:**")
|
395 |
+
st.markdown("- Use better lighting")
|
396 |
+
st.markdown("- Get closer to the butterfly")
|
397 |
+
st.markdown("- Ensure the butterfly is clearly visible")
|
398 |
+
st.markdown("- Avoid blurry or dark images")
|
399 |
else:
|
400 |
+
st.error("Unable to analyze image. Please try again.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
401 |
|
402 |
except Exception as e:
|
403 |
st.error(f"Error processing image: {str(e)}")
|
|
|
407 |
st.markdown("### How to use:")
|
408 |
st.markdown("1. **Camera Capture**: Take a photo using your device camera")
|
409 |
st.markdown("2. **Upload Image**: Choose a butterfly photo from your device")
|
410 |
+
st.markdown("3. **Best Results**: Use clear, well-lit photos with the butterfly clearly visible")
|
411 |
+
|
412 |
+
# Debug info (only show if there are issues)
|
413 |
+
if st.checkbox("Show debug information"):
|
414 |
+
st.markdown("### Debug Information")
|
415 |
+
st.write(f"Number of classes: {len(class_names)}")
|
416 |
+
st.write(f"Model loaded: {model is not None}")
|
417 |
+
if model:
|
418 |
+
st.write("Model architecture successfully detected and loaded")
|
419 |
+
else:
|
420 |
+
st.write("❌ Model failed to load - check console for details")
|