leynessa committed on
Commit e11e26a · verified · 1 Parent(s): 66ead19

Update streamlit_app.py

Files changed (1)
  1. streamlit_app.py +55 -292
streamlit_app.py CHANGED
@@ -1,18 +1,20 @@
  import streamlit as st
  from PIL import Image
  import torch
- from torchvision import models, transforms
  import json
  import os
  import io
  import numpy as np
  import timm
  import warnings
  warnings.filterwarnings('ignore')

  # Configure Streamlit
  st.set_page_config(
-     page_title="Butterfly Identifier/ Liblikamaja ID",
      page_icon="🦋",
      layout="wide"
  )
@@ -40,360 +42,121 @@ def load_butterfly_info():

  butterfly_info = load_butterfly_info()

- def detect_model_architecture(model_state_dict):
-     """Detect which EfficientNet variant was used based on layer shapes"""
-
-     # Check the classifier input features to determine model type
-     if 'classifier.weight' in model_state_dict:
-         classifier_input_features = model_state_dict['classifier.weight'].shape[1]
-
-         # EfficientNet feature mapping (updated based on actual timm implementations)
-         efficientnet_features = {
-             1280: 'efficientnet_b0',
-             1408: 'efficientnet_b1',
-             1536: 'efficientnet_b2',
-             1792: 'efficientnet_b3', # Your model has 1792 features
-             1920: 'efficientnet_b4',
-             2048: 'efficientnet_b5',
-             2304: 'efficientnet_b6',
-             2560: 'efficientnet_b7'
-         }
-
-         model_name = efficientnet_features.get(classifier_input_features, 'efficientnet_b3')
-         print(f"Detected model architecture: {model_name} (classifier features: {classifier_input_features})")
-         return model_name, classifier_input_features
-
-     # Check bn2 layer (final batch norm before classifier)
-     if 'bn2.weight' in model_state_dict:
-         bn2_features = model_state_dict['bn2.weight'].shape[0]
-         print(f"bn2 features: {bn2_features}")
-
-         # This should match the classifier input features
-         bn2_mapping = {
-             1280: 'efficientnet_b0',
-             1408: 'efficientnet_b1',
-             1536: 'efficientnet_b2',
-             1792: 'efficientnet_b3', # Your model shows 1792
-             1920: 'efficientnet_b4',
-             2048: 'efficientnet_b5',
-             2304: 'efficientnet_b6',
-             2560: 'efficientnet_b7'
-         }
-
-         model_name = bn2_mapping.get(bn2_features, 'efficientnet_b3')
-         print(f"Detected model architecture from bn2: {model_name}")
-         return model_name, bn2_features
-
-     # Default to B3 based on your error logs
-     print("Could not detect model architecture, defaulting to efficientnet_b3")
-     return 'efficientnet_b3', 1792

  @st.cache_resource
  def load_model():
      MODEL_PATH = "butterfly_classifier.pth"
-
-     # Check if model file exists
      if not os.path.exists(MODEL_PATH):
-         st.error(f"Model file '{MODEL_PATH}' not found!")
-         return None
-
-     # Check file size
-     file_size = os.path.getsize(MODEL_PATH)
-     if file_size < 1000: # Less than 1KB suggests LFS pointer file
-         st.error(f"""
-         🚨 **Git LFS Issue Detected**
-
-         The model file appears to be a Git LFS pointer file (size: {file_size} bytes).
-         This means the actual model wasn't downloaded properly.
-
-         **To fix this:**
-         1. Run: `git lfs pull` in your repository
-         2. Or download the model file directly from your storage
-         """)
          return None
-
      try:
-         # Load checkpoint
-         print(f"Loading model from {MODEL_PATH} (size: {file_size} bytes)")
          checkpoint = torch.load(MODEL_PATH, map_location='cpu')
-
-         # Extract model components
-         if 'model_state_dict' in checkpoint:
-             model_state_dict = checkpoint['model_state_dict']
-             saved_class_names = checkpoint.get('class_names', class_names)
-         else:
-             model_state_dict = checkpoint
-             saved_class_names = class_names
-
-         # Debug: Print some key layer shapes
-         print("Key layer shapes in checkpoint:")
-         for key in ['conv_stem.weight', 'bn1.weight', 'bn2.weight', 'classifier.weight']:
-             if key in model_state_dict:
-                 print(f" {key}: {model_state_dict[key].shape}")
-
-         # Auto-detect the correct model architecture
-         model_name, expected_features = detect_model_architecture(model_state_dict)
-
-         # Get number of classes
          num_classes = len(class_names)
-         if 'classifier.weight' in model_state_dict:
-             num_classes = model_state_dict['classifier.weight'].shape[0]
-
-         print(f"Loading {model_name} with {num_classes} classes")
-
-         # Create model with correct architecture
-         # Try different parameter combinations that might have been used during training
-         model_configs = [
-             # Most likely configuration based on your checkpoint
-             {'drop_rate': 0.4, 'drop_path_rate': 0.3},
-             {'drop_rate': 0.3, 'drop_path_rate': 0.2},
-             {'drop_rate': 0.2, 'drop_path_rate': 0.1},
-             {'drop_rate': 0.0, 'drop_path_rate': 0.0}, # Default
-         ]
-
-         model = None
-         for config in model_configs:
-             try:
-                 print(f"Trying model config: {config}")
-                 model = timm.create_model(
-                     model_name,
-                     pretrained=False,
-                     num_classes=num_classes,
-                     **config
-                 )
-
-                 # Try loading with strict=True first
-                 model.load_state_dict(model_state_dict, strict=True)
-                 print(f"Model loaded successfully with config: {config}")
-                 break
-
-             except RuntimeError as e:
-                 print(f"Config {config} failed: {e}")
-                 continue
-
-         # If strict loading failed for all configs, try with strict=False
-         if model is None:
-             print("All strict loading attempts failed, trying with strict=False")
-             model = timm.create_model(
-                 model_name,
-                 pretrained=False,
-                 num_classes=num_classes,
-                 drop_rate=0.4,
-                 drop_path_rate=0.3
-             )
-
-             missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
-
-             if missing_keys:
-                 print(f"Missing keys: {missing_keys}")
-                 st.warning(f"⚠️ Some model weights were not loaded: {len(missing_keys)} missing keys")
-
-             if unexpected_keys:
-                 print(f"Unexpected keys: {unexpected_keys}")
-                 st.warning(f"⚠️ Some checkpoint keys were not used: {len(unexpected_keys)} unexpected keys")
-
-         # Verify the model loaded correctly
          model.eval()
-
-         # Test with a dummy input to make sure everything works
-         dummy_input = torch.randn(1, 3, 224, 224)
-         with torch.no_grad():
-             try:
-                 dummy_output = model(dummy_input)
-                 print(f"Model test successful. Output shape: {dummy_output.shape}")
-
-                 # Verify output shape matches expected classes
-                 if dummy_output.shape[1] != num_classes:
-                     st.error(f"Model output mismatch: expected {num_classes} classes, got {dummy_output.shape[1]}")
-                     return None
-
-             except Exception as e:
-                 print(f"Model test failed: {e}")
-                 st.error(f"Model validation failed: {e}")
-                 return None
-
-         #st.success(f"✅ Successfully loaded {model_name} with {num_classes} classes")
          return model
-
      except Exception as e:
          st.error(f"Error loading model: {str(e)}")
-         st.error(f"Model file size: {file_size} bytes")
-
-         # Additional debugging info
-         try:
-             checkpoint = torch.load(MODEL_PATH, map_location='cpu')
-             if 'model_state_dict' in checkpoint:
-                 model_keys = list(checkpoint['model_state_dict'].keys())
-                 print(f"Available keys in checkpoint: {model_keys[:10]}...")
-
-                 # Show the problematic layer shapes
-                 state_dict = checkpoint['model_state_dict']
-                 if 'classifier.weight' in state_dict and 'bn2.weight' in state_dict:
-                     classifier_features = state_dict['classifier.weight'].shape[1]
-                     bn2_features = state_dict['bn2.weight'].shape[0]
-                     print(f"Classifier input features: {classifier_features}")
-                     print(f"bn2 features: {bn2_features}")
-
-                     if classifier_features != bn2_features:
-                         st.error(f"Architecture mismatch: classifier expects {classifier_features} features, but bn2 has {bn2_features}")
-
-         except Exception as debug_e:
-             print(f"Debug info failed: {debug_e}")
-
          return None

- # Load the model
  model = load_model()

- if model is None:
-     st.error("⚠️ **Model Loading Failed**")
-     st.info("**Possible solutions:**")
-     st.markdown("1. **Git LFS issue**: Run `git lfs pull` to download the actual model file")
-     st.markdown("2. **Architecture mismatch**: The model was trained with different parameters")
-     st.markdown("3. **Corrupted file**: Re-download or re-train the model")
-     st.markdown("4. **Check the console/logs** for detailed error information")
-     st.stop()
-
- # Transform for preprocessing (same as training)
- transform = transforms.Compose([
-     transforms.Resize((224, 224)),
-     transforms.ToTensor(),
-     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
- ])
-
  def predict_butterfly(image, threshold=0.5):
-     """Predict butterfly species from image"""
      try:
          if image is None:
              return None, None
-
          if isinstance(image, np.ndarray):
              image = Image.fromarray(image)
          if image.mode != 'RGB':
              image = image.convert('RGB')
-
-         input_tensor = transform(image).unsqueeze(0)
-
          with torch.no_grad():
              output = model(input_tensor)
              probabilities = torch.nn.functional.softmax(output[0], dim=0)
-
          confidence, pred = torch.max(probabilities, 0)
          if confidence.item() < threshold:
              return None, confidence.item()
-
          predicted_class = class_names[pred.item()]
          return predicted_class, confidence.item()
-
      except Exception as e:
          st.error(f"Prediction error: {str(e)}")
          return None, None

-
  # UI Code
- st.title("🦋 Liblikamaja ID/ Butterfly Identifier")
- st.write("Tuvasta liblikaid oma kaamera abil või laadi üles pilt!/ Identify butterflies using your camera or by uploading an image!")
-
- # Show model info
- #if model is not None:
- #st.info(f"📊 Model loaded: {len(class_names)} butterfly species recognized")

- # Create tabs for different input methods
- tab1, tab2 = st.tabs(["📷 Live Camera", "📁 Upload Image"])

  with tab1:
-     st.header("Kaamera jäädvustamine/ Camera Capture")
-     st.write("Take a photo of a butterfly for identification!")
-
-     camera_photo = st.camera_input("Pildista liblikat/ Take a picture of a butterfly")
-
      if camera_photo is not None:
          try:
              image = Image.open(camera_photo).convert("RGB")
-
              col1, col2 = st.columns(2)
-
              with col1:
-                 st.image(image, caption="Captured Image/Jäädvustatud pilt", use_column_width=True)
-
              with col2:
-                 with st.spinner("Kujutise analüüsimine..."):
                      predicted_class, confidence = predict_butterfly(image)
-
                  if predicted_class and confidence >= 0.50:
-                     st.success(f"**Liblika: {predicted_class}**")
-                     #st.info(f"Confidence: {confidence:.2%}")
                  else:
-                     st.warning("⚠️ **I don't know what butterfly this is.**")
-                     st.warning("⚠️ **Low confidence prediction/ Madala usaldusväärsusega ennustus**")
-                     #st.info(f"Best guess: {predicted_class} ({confidence:.1%})")
-                     st.markdown("**Tips for better results/ Näpunäited paremate tulemuste saavutamiseks:**")
-                     st.markdown("- Use better lighting/ Kasutage paremat valgustust")
-                     st.markdown("- Ensure the butterfly is clearly visible/ Veenduge, et liblikas oleks selgelt nähtav")
-                     st.markdown("- Avoid blurry or dark images/ Vältige uduseid või tumedaid pilte")
-
          except Exception as e:
              st.error(f"Error processing image: {str(e)}")

  with tab2:
-     st.header("Upload Image/ Laadi pilt üles")
-     st.write("Upload a clear photo of a butterfly for identification/ Laadige üles liblika selge foto tuvastamiseks")
-
-     uploaded_file = st.file_uploader(
-         "Choose an image/ Valige pilt...",
-         type=["jpg", "jpeg", "png"],
-         help="Upload a clear photo of a butterfly/ Laadi üles selge foto liblikast"
-     )
-
      if uploaded_file is not None:
          try:
              image_bytes = uploaded_file.read()
              image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
-
              col1, col2 = st.columns(2)
-
              with col1:
-                 st.image(image, caption="Uploaded Image", use_column_width=True)
-
              with col2:
-                 with st.spinner("Analyzing image..."):
                      predicted_class, confidence = predict_butterfly(image)
-
-                 if predicted_class and confidence:
-                     if confidence >= 0.50:
-                         st.success(f"**liblikas: {predicted_class}**")
-                         #st.info(f"Confidence: {confidence:.2%}")
-
-                         # Show additional info if available
-                         #if predicted_class in butterfly_info:
-                         # st.write(f"**About:** {butterfly_info[predicted_class]['description']}")
-                     else:
-                         st.warning(" **Low confidence prediction/ Madala usaldusväärsusega ennustus**")
-                         #st.info(f"Best guess: {predicted_class} ({confidence:.1%})")
-                         st.markdown("**Tips for better results/ Näpunäited paremate tulemuste saavutamiseks:**")
-                         st.markdown("- Use better lighting/ Kasutage paremat valgustust")
-                         st.markdown("- Ensure the butterfly is clearly visible/ Veenduge, et liblikas oleks selgelt nähtav")
-                         st.markdown("- Avoid blurry or dark images/ Vältige uduseid või tumedaid pilte")
-
          except Exception as e:
              st.error(f"Error processing image: {str(e)}")

- # Add footer
  st.markdown("---")
- st.markdown("### How to use/ Kuidas kasutada:")
- st.markdown("1. **Camera Capture**: Take a photo using your device camera/ **Kaamera jäädvustamine**: Tehke foto oma seadme kaameraga")
- st.markdown("2. **Upload Image**: Choose a butterfly photo from your device/ **Laadi pilt üles**: Vali oma seadmest liblika foto")
- st.markdown("3. **Best Results**: Use clear, well-lit photos with the butterfly clearly visible/ **Parimad tulemused**: Kasutage selgeid ja hästi valgustatud fotosid, kus liblikas on selgelt nähtav.")

- # Debug info (only show if there are issues)
- #if st.checkbox("Show debug information"):
- # st.markdown("### Debug Information")
- # st.write(f"Number of classes: {len(class_names)}")
- # st.write(f"Model loaded: {model is not None}")
- # if model:
- # st.write("Model architecture successfully detected and loaded")
- # else:
- # st.write("❌ Model failed to load - check console for details")
 
+ # Full corrected bilingual Streamlit app for Butterfly Identifier
  import streamlit as st
  from PIL import Image
  import torch
  import json
  import os
  import io
  import numpy as np
  import timm
+ import albumentations as A
+ from albumentations.pytorch import ToTensorV2
  import warnings
  warnings.filterwarnings('ignore')

  # Configure Streamlit
  st.set_page_config(
+     page_title="Butterfly Identifier / Liblikamaja ID",
      page_icon="🦋",
      layout="wide"
  )

  butterfly_info = load_butterfly_info()

+ # Define transform matching training pipeline
+ inference_transform = A.Compose([
+     A.Resize(224, 224),
+     A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     ToTensorV2()
+ ])

+ # Load the model
  @st.cache_resource
  def load_model():
      MODEL_PATH = "butterfly_classifier.pth"
      if not os.path.exists(MODEL_PATH):
+         st.error("Model file not found!")
          return None
      try:
          checkpoint = torch.load(MODEL_PATH, map_location='cpu')
+         model_state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
          num_classes = len(class_names)
+         model = timm.create_model('efficientnet_b3', pretrained=False, num_classes=num_classes, drop_rate=0.4, drop_path_rate=0.3)
+         model.load_state_dict(model_state_dict, strict=False)
          model.eval()
          return model
      except Exception as e:
          st.error(f"Error loading model: {str(e)}")
          return None

  model = load_model()

  def predict_butterfly(image, threshold=0.5):
      try:
          if image is None:
              return None, None
          if isinstance(image, np.ndarray):
              image = Image.fromarray(image)
          if image.mode != 'RGB':
              image = image.convert('RGB')
+         transformed = inference_transform(image=np.array(image))
+         input_tensor = transformed['image'].unsqueeze(0)
          with torch.no_grad():
              output = model(input_tensor)
              probabilities = torch.nn.functional.softmax(output[0], dim=0)
          confidence, pred = torch.max(probabilities, 0)
          if confidence.item() < threshold:
              return None, confidence.item()
          predicted_class = class_names[pred.item()]
          return predicted_class, confidence.item()
      except Exception as e:
          st.error(f"Prediction error: {str(e)}")
          return None, None

  # UI Code
+ st.title("🦋 Liblikamaja ID / Butterfly Identifier")
+ st.write("Tuvasta liblikaid oma kaamera abil või laadi üles pilt! / Identify butterflies using your camera or by uploading an image!")

+ tab1, tab2 = st.tabs(["📷 Live Camera / Kaamera", "📁 Upload Image / Laadi üles"])

  with tab1:
+     st.header("Kaamera jäädvustamine / Camera Capture")
+     st.write("Tee pilt liblikast tuvastamiseks / Take a photo of a butterfly for identification.")
+     camera_photo = st.camera_input("Pildista liblikat / Capture a butterfly")
      if camera_photo is not None:
          try:
              image = Image.open(camera_photo).convert("RGB")
              col1, col2 = st.columns(2)
              with col1:
+                 st.image(image, caption="Jäädvustatud pilt / Captured Image", use_column_width=True)
              with col2:
+                 with st.spinner("Pildi analüüsimine... / Analyzing image..."):
                      predicted_class, confidence = predict_butterfly(image)
                  if predicted_class and confidence >= 0.50:
+                     st.success(f"**Liblikas / Butterfly: {predicted_class}**")
+                     if predicted_class in butterfly_info:
+                         st.markdown("**Liigi kirjeldus / About this species:**")
+                         st.write(butterfly_info[predicted_class]["description"])
                  else:
+                     st.warning("⚠️ Ma ei tea, mis liblikas see on / I don't know what butterfly this is")
+                     st.markdown("**Näpunäited paremate tulemuste saavutamiseks / Tips for better results:**")
+                     st.markdown("- Kasutage paremat valgustust / Use better lighting")
+                     st.markdown("- Veenduge, et liblikas oleks selgelt nähtav / Ensure the butterfly is clearly visible")
+                     st.markdown("- Vältige uduseid või tumedaid pilte / Avoid blurry or dark images")
          except Exception as e:
              st.error(f"Error processing image: {str(e)}")

  with tab2:
+     st.header("Laadi üles pilt / Upload Image")
+     st.write("Laadige üles liblika selge foto tuvastamiseks / Upload a clear photo of a butterfly for identification.")
+     uploaded_file = st.file_uploader("Vali pilt... / Choose an image...", type=["jpg", "jpeg", "png"])
      if uploaded_file is not None:
          try:
              image_bytes = uploaded_file.read()
              image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
              col1, col2 = st.columns(2)
              with col1:
+                 st.image(image, caption="Üleslaetud pilt / Uploaded Image", use_column_width=True)
              with col2:
+                 with st.spinner("Pildi analüüsimine... / Analyzing image..."):
                      predicted_class, confidence = predict_butterfly(image)
+                 if predicted_class and confidence >= 0.50:
+                     st.success(f"**Liblikas / Butterfly: {predicted_class}**")
+                     if predicted_class in butterfly_info:
+                         st.markdown("**Liigi kirjeldus / About this species:**")
+                         st.write(butterfly_info[predicted_class]["description"])
+                 else:
+                     st.warning("⚠️ Ma ei tea, mis liblikas see on / I don't know what butterfly this is")
+                     st.markdown("**Näpunäited paremate tulemuste saavutamiseks / Tips for better results:**")
+                     st.markdown("- Kasutage paremat valgustust / Use better lighting")
+                     st.markdown("- Veenduge, et liblikas oleks selgelt nähtav / Ensure the butterfly is clearly visible")
+                     st.markdown("- Vältige uduseid või tumedaid pilte / Avoid blurry or dark images")
          except Exception as e:
              st.error(f"Error processing image: {str(e)}")

+ # Footer
  st.markdown("---")
+ st.markdown("### Kuidas kasutada / How to use:")
+ st.markdown("1. **Kaamera jäädvustamine / Camera Capture**: Tehke foto oma seadme kaameraga / Take a photo using your device camera")
+ st.markdown("2. **Laadi pilt üles / Upload Image**: Vali oma seadmest liblika foto / Choose a butterfly photo from your device")
+ st.markdown("3. **Parimad tulemused / Best Results**: Kasutage selgeid ja hästi valgustatud fotosid, kus liblikas on selgelt nähtav / Use clear, well-lit photos with the butterfly clearly visible")