leynessa commited on
Commit
dbb62b2
·
verified ·
1 Parent(s): 9d53d92

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +363 -43
streamlit_app.py CHANGED
@@ -1,4 +1,4 @@
1
- # Full corrected bilingual Streamlit app for Butterfly Identifier
2
  import streamlit as st
3
  from PIL import Image
4
  import torch
@@ -49,136 +49,448 @@ inference_transform = A.Compose([
49
  ToTensorV2()
50
  ])
51
 
52
- # Load the model
53
  @st.cache_resource
54
  def load_model():
55
- MODEL_PATH = "butterfly_classifier.pth"
56
- if not os.path.exists(MODEL_PATH):
57
- st.error("Model file not found!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  return None
 
 
 
59
  try:
 
60
  checkpoint = torch.load(MODEL_PATH, map_location='cpu')
61
- model_state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
62
- num_classes = len(class_names)
63
-
64
- # Attempt to auto-detect model from batch norm layer dimensions
65
- bn2_shape = None
66
- for key in model_state_dict:
67
- if key.endswith("bn2.weight"):
68
- bn2_shape = model_state_dict[key].shape[0]
69
- break
70
-
71
- feature_map = {
72
- 1280: 'efficientnet_b0',
73
- 1408: 'efficientnet_b1',
74
- 1536: 'efficientnet_b2',
75
- 1792: 'efficientnet_b3',
76
- 1920: 'efficientnet_b4',
77
- 2048: 'efficientnet_b5',
78
- 2304: 'efficientnet_b6',
79
- 2560: 'efficientnet_b7'
80
- }
81
-
82
- if bn2_shape is None:
83
- st.warning("Could not detect classifier or bn2 layer in checkpoint. Defaulting to efficientnet_b3")
84
- model_name = 'efficientnet_b3'
85
  else:
86
- model_name = feature_map.get(bn2_shape, 'efficientnet_b3')
87
- st.info(f"Detected model architecture: {model_name}")
88
-
89
- model = timm.create_model(model_name, pretrained=False, num_classes=num_classes, drop_rate=0.4, drop_path_rate=0.3)
90
- model.load_state_dict(model_state_dict, strict=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  model.eval()
 
 
 
 
 
 
 
 
92
  return model
 
93
  except Exception as e:
94
- st.error(f"Error loading model: {str(e)}")
 
95
  return None
96
 
 
97
  model = load_model()
98
 
99
  def predict_butterfly(image, threshold=0.5):
 
100
  try:
101
  if model is None:
102
  raise ValueError("Model is not loaded.")
103
  if image is None:
104
  return None, None
 
 
105
  if isinstance(image, np.ndarray):
106
  image = Image.fromarray(image)
107
  if image.mode != 'RGB':
108
  image = image.convert('RGB')
 
 
109
  transformed = inference_transform(image=np.array(image))
110
  input_tensor = transformed['image'].unsqueeze(0)
 
 
111
  with torch.no_grad():
112
  output = model(input_tensor)
113
  probabilities = torch.nn.functional.softmax(output[0], dim=0)
114
  confidence, pred = torch.max(probabilities, 0)
 
115
  if confidence.item() < threshold:
116
  return None, confidence.item()
 
117
  predicted_class = class_names[pred.item()]
118
  return predicted_class, confidence.item()
 
119
  except Exception as e:
120
  st.error(f"Prediction error: {str(e)}")
121
  return None, None
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  # UI Code
124
- st.title("🦋 Liblikamaja ID / Butterfly Identifier")
125
  st.write("Tuvasta liblikaid oma kaamera abil või laadi üles pilt! / Identify butterflies using your camera or by uploading an image!")
126
 
127
- tab1, tab2 = st.tabs(["📷 Live Camera / Kaamera", "📁 Upload Image / Laadi üles"])
 
 
 
 
 
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  with tab1:
131
  st.header("Kaamera jäädvustamine / Camera Capture")
132
  st.write("Tee pilt liblikast tuvastamiseks / Take a photo of a butterfly for identification.")
 
133
  camera_photo = st.camera_input("Pildista liblikat / Capture a butterfly")
 
134
  if camera_photo is not None:
135
  try:
136
  image = Image.open(camera_photo).convert("RGB")
137
  col1, col2 = st.columns(2)
 
138
  with col1:
139
  st.image(image, caption="Jäädvustatud pilt / Captured Image", use_column_width=True)
 
140
  with col2:
141
  with st.spinner("Pildi analüüsimine... / Analyzing image..."):
142
- predicted_class, confidence = predict_butterfly(image)
143
- if predicted_class and confidence >= 0.50:
 
 
 
 
144
  st.success(f"**Liblikas / Butterfly: {predicted_class}**")
 
 
145
  if predicted_class in butterfly_info:
146
  st.markdown("**Liigi kirjeldus / About this species:**")
147
  st.write(butterfly_info[predicted_class]["description"])
 
 
148
  else:
149
- st.warning("⚠️ Ma ei tea, mis liblikas see on / I don't know what butterfly this is")
 
150
  st.markdown("**Näpunäited paremate tulemuste saavutamiseks / Tips for better results:**")
151
  st.markdown("- Kasutage paremat valgustust / Use better lighting")
152
  st.markdown("- Veenduge, et liblikas oleks selgelt nähtav / Ensure the butterfly is clearly visible")
153
  st.markdown("- Vältige uduseid või tumedaid pilte / Avoid blurry or dark images")
 
 
154
  except Exception as e:
155
  st.error(f"Error processing image: {str(e)}")
156
 
157
  with tab2:
158
  st.header("Laadi üles pilt / Upload Image")
159
  st.write("Laadige üles liblika selge foto tuvastamiseks / Upload a clear photo of a butterfly for identification.")
 
160
  uploaded_file = st.file_uploader("Vali pilt... / Choose an image...", type=["jpg", "jpeg", "png"])
 
161
  if uploaded_file is not None:
162
  try:
163
  image_bytes = uploaded_file.read()
164
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
165
  col1, col2 = st.columns(2)
 
166
  with col1:
167
  st.image(image, caption="Üleslaetud pilt / Uploaded Image", use_column_width=True)
 
168
  with col2:
169
  with st.spinner("Pildi analüüsimine... / Analyzing image..."):
170
- predicted_class, confidence = predict_butterfly(image)
171
- if predicted_class and confidence >= 0.50:
 
 
 
 
172
  st.success(f"**Liblikas / Butterfly: {predicted_class}**")
 
 
173
  if predicted_class in butterfly_info:
174
  st.markdown("**Liigi kirjeldus / About this species:**")
175
  st.write(butterfly_info[predicted_class]["description"])
 
 
176
  else:
177
- st.warning("⚠️ Ma ei tea, mis liblikas see on / I don't know what butterfly this is")
 
178
  st.markdown("**Näpunäited paremate tulemuste saavutamiseks / Tips for better results:**")
179
  st.markdown("- Kasutage paremat valgustust / Use better lighting")
180
  st.markdown("- Veenduge, et liblikas oleks selgelt nähtav / Ensure the butterfly is clearly visible")
181
  st.markdown("- Vältige uduseid või tumedaid pilte / Avoid blurry or dark images")
 
 
182
  except Exception as e:
183
  st.error(f"Error processing image: {str(e)}")
184
 
@@ -188,4 +500,12 @@ st.markdown("### Kuidas kasutada / How to use:")
188
  st.markdown("1. **Kaamera jäädvustamine / Camera Capture**: Tehke foto oma seadme kaameraga / Take a photo using your device camera")
189
  st.markdown("2. **Laadi pilt üles / Upload Image**: Vali oma seadmest liblika foto / Choose a butterfly photo from your device")
190
  st.markdown("3. **Parimad tulemused / Best Results**: Kasutage selgeid ja hästi valgustatud fotosid, kus liblikas on selgelt nähtav / Use clear, well-lit photos with the butterfly clearly visible")
 
191
 
 
 
 
 
 
 
 
 
1
+ # Enhanced Butterfly Identifier Streamlit App with Better Model Loading
2
  import streamlit as st
3
  from PIL import Image
4
  import torch
 
49
  ToTensorV2()
50
  ])
51
 
52
+ # Enhanced model loading function
53
  @st.cache_resource
54
  def load_model():
55
+ """Enhanced model loading with architecture detection and fallback options"""
56
+
57
+ # Try different model file names
58
+ model_files = [
59
+ "butterfly_classifier.pth",
60
+ "best_butterfly_model_v3.pth",
61
+ "best_butterfly_model.pth"
62
+ ]
63
+
64
+ MODEL_PATH = None
65
+ for model_file in model_files:
66
+ if os.path.exists(model_file):
67
+ MODEL_PATH = model_file
68
+ break
69
+
70
+ if MODEL_PATH is None:
71
+ st.error("No model file found! Please ensure one of these files exists: " + ", ".join(model_files))
72
  return None
73
+
74
+ st.info(f"Loading model from: {MODEL_PATH}")
75
+
76
  try:
77
+ # Load checkpoint
78
  checkpoint = torch.load(MODEL_PATH, map_location='cpu')
79
+
80
+ # Extract model state dict
81
+ if 'model_state_dict' in checkpoint:
82
+ model_state_dict = checkpoint['model_state_dict']
83
+ if 'class_names' in checkpoint:
84
+ st.info(f"Model trained on {len(checkpoint['class_names'])} classes")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  else:
86
+ model_state_dict = checkpoint
87
+
88
+ num_classes = len(class_names)
89
+
90
+ # Architecture detection based on model state dict
91
+ def detect_model_architecture(state_dict):
92
+ """Detect model architecture from state dict"""
93
+
94
+ # Check for EfficientNet variants by looking at key layer dimensions
95
+ architecture_indicators = {
96
+ 'conv_head.weight': 'efficientnet',
97
+ 'head.weight': 'efficientnet_v2',
98
+ 'classifier.weight': 'other'
99
+ }
100
+
101
+ # Look for specific layer patterns
102
+ for key in state_dict.keys():
103
+ if 'conv_head.weight' in key:
104
+ shape = state_dict[key].shape
105
+ if len(shape) >= 2:
106
+ feature_dim = shape[1]
107
+ # EfficientNet feature dimensions
108
+ efficientnet_map = {
109
+ 1280: 'efficientnet_b0',
110
+ 1408: 'efficientnet_b1',
111
+ 1536: 'efficientnet_b2',
112
+ 1792: 'efficientnet_b3',
113
+ 1920: 'efficientnet_b4',
114
+ 2048: 'efficientnet_b5',
115
+ 2304: 'efficientnet_b6',
116
+ 2560: 'efficientnet_b7'
117
+ }
118
+ return efficientnet_map.get(feature_dim, 'efficientnet_b3')
119
+
120
+ if 'head.weight' in key:
121
+ shape = state_dict[key].shape
122
+ if len(shape) >= 2:
123
+ feature_dim = shape[1]
124
+ # EfficientNetV2 feature dimensions
125
+ efficientnetv2_map = {
126
+ 1280: 'tf_efficientnetv2_s',
127
+ 1408: 'tf_efficientnetv2_m',
128
+ 1792: 'tf_efficientnetv2_l'
129
+ }
130
+ return efficientnetv2_map.get(feature_dim, 'tf_efficientnetv2_s')
131
+
132
+ # Fallback: check bn2 layer for EfficientNet variants
133
+ for key in state_dict.keys():
134
+ if key.endswith("bn2.weight"):
135
+ bn2_shape = state_dict[key].shape[0]
136
+ feature_map = {
137
+ 1280: 'efficientnet_b0',
138
+ 1408: 'efficientnet_b1',
139
+ 1536: 'efficientnet_b2',
140
+ 1792: 'efficientnet_b3',
141
+ 1920: 'efficientnet_b4',
142
+ 2048: 'efficientnet_b5',
143
+ 2304: 'efficientnet_b6',
144
+ 2560: 'efficientnet_b7'
145
+ }
146
+ return feature_map.get(bn2_shape, 'efficientnet_b3')
147
+
148
+ return 'efficientnet_b3' # Default fallback
149
+
150
+ # Detect architecture
151
+ detected_arch = detect_model_architecture(model_state_dict)
152
+ st.info(f"Detected model architecture: {detected_arch}")
153
+
154
+ # List of architectures to try in order
155
+ architectures_to_try = [
156
+ detected_arch,
157
+ 'efficientnet_b3',
158
+ 'efficientnet_b2',
159
+ 'efficientnet_b0',
160
+ 'efficientnet_b1',
161
+ 'efficientnet_b4',
162
+ 'tf_efficientnetv2_s',
163
+ 'tf_efficientnetv2_m'
164
+ ]
165
+
166
+ # Remove duplicates while preserving order
167
+ seen = set()
168
+ architectures_to_try = [x for x in architectures_to_try if not (x in seen or seen.add(x))]
169
+
170
+ model = None
171
+ successful_arch = None
172
+
173
+ # Try each architecture
174
+ for arch in architectures_to_try:
175
+ try:
176
+ st.info(f"Trying architecture: {arch}")
177
+
178
+ # Create model with the detected/guessed architecture
179
+ model = timm.create_model(
180
+ arch,
181
+ pretrained=False,
182
+ num_classes=num_classes,
183
+ drop_rate=0.4,
184
+ drop_path_rate=0.3
185
+ )
186
+
187
+ # Check if the model has a custom head/classifier in the checkpoint
188
+ if any('head.' in key for key in model_state_dict.keys()):
189
+ # Model has custom head - try to load it
190
+ try:
191
+ model.load_state_dict(model_state_dict, strict=False)
192
+ st.success(f"✅ Successfully loaded model with architecture: {arch}")
193
+ successful_arch = arch
194
+ break
195
+ except Exception as e:
196
+ st.warning(f"Failed to load custom head for {arch}: {str(e)}")
197
+ continue
198
+
199
+ elif any('classifier.' in key for key in model_state_dict.keys()):
200
+ # Model has custom classifier - try to load it
201
+ try:
202
+ model.load_state_dict(model_state_dict, strict=False)
203
+ st.success(f"✅ Successfully loaded model with architecture: {arch}")
204
+ successful_arch = arch
205
+ break
206
+ except Exception as e:
207
+ st.warning(f"Failed to load custom classifier for {arch}: {str(e)}")
208
+ continue
209
+
210
+ else:
211
+ # Try to create custom head/classifier and load backbone
212
+ try:
213
+ # Load backbone weights (ignore head/classifier mismatches)
214
+ backbone_dict = {k: v for k, v in model_state_dict.items()
215
+ if not (k.startswith('head.') or k.startswith('classifier.'))}
216
+
217
+ model.load_state_dict(backbone_dict, strict=False)
218
+
219
+ # Create new head/classifier
220
+ if hasattr(model, 'classifier'):
221
+ in_features = model.classifier.in_features
222
+ model.classifier = torch.nn.Linear(in_features, num_classes)
223
+ elif hasattr(model, 'head'):
224
+ in_features = model.head.in_features
225
+ model.head = torch.nn.Linear(in_features, num_classes)
226
+
227
+ st.warning(f"⚠️ Loaded {arch} with new head/classifier (backbone weights only)")
228
+ successful_arch = arch
229
+ break
230
+
231
+ except Exception as e:
232
+ st.warning(f"Failed to load backbone for {arch}: {str(e)}")
233
+ continue
234
+
235
+ except Exception as e:
236
+ st.warning(f"Failed to create model {arch}: {str(e)}")
237
+ continue
238
+
239
+ if model is None:
240
+ st.error("❌ Failed to load model with any architecture!")
241
+ return None
242
+
243
+ # Set model to evaluation mode
244
  model.eval()
245
+
246
+ # Display model info
247
+ total_params = sum(p.numel() for p in model.parameters())
248
+ st.success(f"✅ Model loaded successfully!")
249
+ st.info(f"📊 Model: {successful_arch}")
250
+ st.info(f"🔢 Parameters: {total_params:,}")
251
+ st.info(f"🎯 Classes: {num_classes}")
252
+
253
  return model
254
+
255
  except Exception as e:
256
+ st.error(f"Error loading model: {str(e)}")
257
+ st.error("Please check your model file and ensure it's compatible")
258
  return None
259
 
260
+ # Load model
261
  model = load_model()
262
 
263
  def predict_butterfly(image, threshold=0.5):
264
+ """Predict butterfly species from image"""
265
  try:
266
  if model is None:
267
  raise ValueError("Model is not loaded.")
268
  if image is None:
269
  return None, None
270
+
271
+ # Convert to PIL Image if needed
272
  if isinstance(image, np.ndarray):
273
  image = Image.fromarray(image)
274
  if image.mode != 'RGB':
275
  image = image.convert('RGB')
276
+
277
+ # Apply transforms
278
  transformed = inference_transform(image=np.array(image))
279
  input_tensor = transformed['image'].unsqueeze(0)
280
+
281
+ # Make prediction
282
  with torch.no_grad():
283
  output = model(input_tensor)
284
  probabilities = torch.nn.functional.softmax(output[0], dim=0)
285
  confidence, pred = torch.max(probabilities, 0)
286
+
287
  if confidence.item() < threshold:
288
  return None, confidence.item()
289
+
290
  predicted_class = class_names[pred.item()]
291
  return predicted_class, confidence.item()
292
+
293
  except Exception as e:
294
  st.error(f"Prediction error: {str(e)}")
295
  return None, None
296
 
297
+ def predict_with_tta(image, threshold=0.5, num_tta=5):
298
+ """Predict with Test Time Augmentation for better accuracy"""
299
+ try:
300
+ if model is None:
301
+ raise ValueError("Model is not loaded.")
302
+ if image is None:
303
+ return None, None
304
+
305
+ # Convert to PIL Image if needed
306
+ if isinstance(image, np.ndarray):
307
+ image = Image.fromarray(image)
308
+ if image.mode != 'RGB':
309
+ image = image.convert('RGB')
310
+
311
+ # Convert to numpy for albumentations
312
+ image_np = np.array(image)
313
+
314
+ # TTA transforms
315
+ tta_transforms = [
316
+ A.Compose([
317
+ A.Resize(224, 224),
318
+ A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
319
+ ToTensorV2()
320
+ ]),
321
+ A.Compose([
322
+ A.Resize(256, 256),
323
+ A.CenterCrop(224, 224),
324
+ A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
325
+ ToTensorV2()
326
+ ]),
327
+ A.Compose([
328
+ A.Resize(224, 224),
329
+ A.HorizontalFlip(p=1.0),
330
+ A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
331
+ ToTensorV2()
332
+ ]),
333
+ A.Compose([
334
+ A.Resize(240, 240),
335
+ A.Rotate(limit=10, p=1.0),
336
+ A.CenterCrop(224, 224),
337
+ A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
338
+ ToTensorV2()
339
+ ]),
340
+ A.Compose([
341
+ A.Resize(224, 224),
342
+ A.ColorJitter(brightness=0.1, contrast=0.1, p=1.0),
343
+ A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
344
+ ToTensorV2()
345
+ ])
346
+ ]
347
+
348
+ predictions = []
349
+
350
+ for i, transform in enumerate(tta_transforms[:num_tta]):
351
+ transformed = transform(image=image_np)
352
+ input_tensor = transformed['image'].unsqueeze(0)
353
+
354
+ with torch.no_grad():
355
+ output = model(input_tensor)
356
+ probabilities = torch.nn.functional.softmax(output, dim=1)
357
+ predictions.append(probabilities)
358
+
359
+ # Average predictions
360
+ avg_predictions = torch.mean(torch.stack(predictions), dim=0)
361
+ confidence, pred = torch.max(avg_predictions, 1)
362
+
363
+ if confidence.item() < threshold:
364
+ return None, confidence.item()
365
+
366
+ predicted_class = class_names[pred.item()]
367
+ return predicted_class, confidence.item()
368
+
369
+ except Exception as e:
370
+ st.error(f"TTA Prediction error: {str(e)}")
371
+ return None, None
372
+
373
  # UI Code
374
+ st.title("🦋 Liblikamaja ID / Butterfly Identifier")
375
  st.write("Tuvasta liblikaid oma kaamera abil või laadi üles pilt! / Identify butterflies using your camera or by uploading an image!")
376
 
377
+ # Add model status indicator
378
+ if model is not None:
379
+ st.success("✅ Model loaded and ready!")
380
+ else:
381
+ st.error("❌ Model not loaded. Please check your model file.")
382
+ st.stop()
383
 
384
+ # Add advanced options
385
+ with st.expander("🔧 Advanced Options / Täpsemad seaded"):
386
+ confidence_threshold = st.slider(
387
+ "Confidence Threshold / Kindluse lävi",
388
+ min_value=0.1,
389
+ max_value=1.0,
390
+ value=0.5,
391
+ step=0.05,
392
+ help="Higher values = more conservative predictions"
393
+ )
394
+
395
+ use_tta = st.checkbox(
396
+ "Use Test Time Augmentation (TTA) / Kasuta TTA",
397
+ value=False,
398
+ help="Slower but potentially more accurate predictions"
399
+ )
400
+
401
+ if use_tta:
402
+ tta_rounds = st.slider(
403
+ "TTA Rounds / TTA ringid",
404
+ min_value=3,
405
+ max_value=8,
406
+ value=5,
407
+ help="More rounds = slower but potentially more accurate"
408
+ )
409
+
410
+ tab1, tab2 = st.tabs(["📷 Live Camera / Kaamera", "📁 Upload Image / Laadi üles"])
411
 
412
  with tab1:
413
  st.header("Kaamera jäädvustamine / Camera Capture")
414
  st.write("Tee pilt liblikast tuvastamiseks / Take a photo of a butterfly for identification.")
415
+
416
  camera_photo = st.camera_input("Pildista liblikat / Capture a butterfly")
417
+
418
  if camera_photo is not None:
419
  try:
420
  image = Image.open(camera_photo).convert("RGB")
421
  col1, col2 = st.columns(2)
422
+
423
  with col1:
424
  st.image(image, caption="Jäädvustatud pilt / Captured Image", use_column_width=True)
425
+
426
  with col2:
427
  with st.spinner("Pildi analüüsimine... / Analyzing image..."):
428
+ if use_tta:
429
+ predicted_class, confidence = predict_with_tta(image, confidence_threshold, tta_rounds)
430
+ else:
431
+ predicted_class, confidence = predict_butterfly(image, confidence_threshold)
432
+
433
+ if predicted_class and confidence >= confidence_threshold:
434
  st.success(f"**Liblikas / Butterfly: {predicted_class}**")
435
+ st.info(f"Confidence: {confidence:.2%}")
436
+
437
  if predicted_class in butterfly_info:
438
  st.markdown("**Liigi kirjeldus / About this species:**")
439
  st.write(butterfly_info[predicted_class]["description"])
440
+ else:
441
+ st.info("No additional information available for this species.")
442
  else:
443
+ confidence_text = f" (Confidence: {confidence:.2%})" if confidence else ""
444
+ st.warning(f"⚠️ Ma ei tea, mis liblikas see on / I don't know what butterfly this is{confidence_text}")
445
  st.markdown("**Näpunäited paremate tulemuste saavutamiseks / Tips for better results:**")
446
  st.markdown("- Kasutage paremat valgustust / Use better lighting")
447
  st.markdown("- Veenduge, et liblikas oleks selgelt nähtav / Ensure the butterfly is clearly visible")
448
  st.markdown("- Vältige uduseid või tumedaid pilte / Avoid blurry or dark images")
449
+ st.markdown("- Proovige madalamat kindluse läviväärtust / Try a lower confidence threshold")
450
+
451
  except Exception as e:
452
  st.error(f"Error processing image: {str(e)}")
453
 
454
  with tab2:
455
  st.header("Laadi üles pilt / Upload Image")
456
  st.write("Laadige üles liblika selge foto tuvastamiseks / Upload a clear photo of a butterfly for identification.")
457
+
458
  uploaded_file = st.file_uploader("Vali pilt... / Choose an image...", type=["jpg", "jpeg", "png"])
459
+
460
  if uploaded_file is not None:
461
  try:
462
  image_bytes = uploaded_file.read()
463
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
464
  col1, col2 = st.columns(2)
465
+
466
  with col1:
467
  st.image(image, caption="Üleslaetud pilt / Uploaded Image", use_column_width=True)
468
+
469
  with col2:
470
  with st.spinner("Pildi analüüsimine... / Analyzing image..."):
471
+ if use_tta:
472
+ predicted_class, confidence = predict_with_tta(image, confidence_threshold, tta_rounds)
473
+ else:
474
+ predicted_class, confidence = predict_butterfly(image, confidence_threshold)
475
+
476
+ if predicted_class and confidence >= confidence_threshold:
477
  st.success(f"**Liblikas / Butterfly: {predicted_class}**")
478
+ st.info(f"Confidence: {confidence:.2%}")
479
+
480
  if predicted_class in butterfly_info:
481
  st.markdown("**Liigi kirjeldus / About this species:**")
482
  st.write(butterfly_info[predicted_class]["description"])
483
+ else:
484
+ st.info("No additional information available for this species.")
485
  else:
486
+ confidence_text = f" (Confidence: {confidence:.2%})" if confidence else ""
487
+ st.warning(f"⚠️ Ma ei tea, mis liblikas see on / I don't know what butterfly this is{confidence_text}")
488
  st.markdown("**Näpunäited paremate tulemuste saavutamiseks / Tips for better results:**")
489
  st.markdown("- Kasutage paremat valgustust / Use better lighting")
490
  st.markdown("- Veenduge, et liblikas oleks selgelt nähtav / Ensure the butterfly is clearly visible")
491
  st.markdown("- Vältige uduseid või tumedaid pilte / Avoid blurry or dark images")
492
+ st.markdown("- Proovige madalamat kindluse läviväärtust / Try a lower confidence threshold")
493
+
494
  except Exception as e:
495
  st.error(f"Error processing image: {str(e)}")
496
 
 
500
  st.markdown("1. **Kaamera jäädvustamine / Camera Capture**: Tehke foto oma seadme kaameraga / Take a photo using your device camera")
501
  st.markdown("2. **Laadi pilt üles / Upload Image**: Vali oma seadmest liblika foto / Choose a butterfly photo from your device")
502
  st.markdown("3. **Parimad tulemused / Best Results**: Kasutage selgeid ja hästi valgustatud fotosid, kus liblikas on selgelt nähtav / Use clear, well-lit photos with the butterfly clearly visible")
503
+ st.markdown("4. **Täpsemad seaded / Advanced Options**: Kohandage kindluse lävi ja kasutage TTA paremate tulemuste saamiseks / Adjust confidence threshold and use TTA for better results")
504
 
505
+ # Debug info
506
+ if st.checkbox("Show Debug Info"):
507
+ st.write("**Class Names:**", class_names)
508
+ st.write("**Number of Classes:**", len(class_names))
509
+ st.write("**Model Status:**", "Loaded" if model else "Not Loaded")
510
+ if butterfly_info:
511
+ st.write("**Species Info Available:**", len(butterfly_info))