File size: 17,278 Bytes
dbb62b2
02001f5
 
 
 
605629e
bfc9fea
 
8c61898
e11e26a
 
1809961
 
02001f5
af316a5
 
e11e26a
af316a5
 
 
 
02001f5
1809961
 
 
 
 
 
 
 
 
 
02001f5
 
1809961
 
 
 
 
8c61898
1809961
 
 
 
e11e26a
 
 
 
 
 
02001f5
dbb62b2
02001f5
 
2edff24
dbb62b2
 
 
 
2edff24
dbb62b2
 
 
 
 
 
 
 
 
 
2edff24
1809961
dbb62b2
 
 
9625ec8
dbb62b2
1809961
dbb62b2
 
 
 
9d53d92
dbb62b2
2edff24
dbb62b2
 
2edff24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbb62b2
2edff24
 
dbb62b2
2edff24
 
dbb62b2
2edff24
 
dbb62b2
2edff24
 
 
 
 
 
dbb62b2
 
 
 
 
 
 
 
 
 
 
 
 
 
2edff24
dbb62b2
 
 
 
2edff24
 
dbb62b2
 
2edff24
 
 
 
 
 
 
 
dbb62b2
 
2edff24
dbb62b2
 
2edff24
 
dbb62b2
 
 
 
 
 
 
 
 
 
 
1809961
dbb62b2
 
 
 
 
 
 
 
1809961
dbb62b2
1809961
dbb62b2
1809961
02001f5
dbb62b2
7fb82fb
02001f5
118e762
dbb62b2
8445a65
7fb82fb
 
8445a65
 
dbb62b2
 
8445a65
 
 
 
dbb62b2
 
e11e26a
 
dbb62b2
 
8445a65
 
 
 
dbb62b2
8c61898
 
dbb62b2
8445a65
 
dbb62b2
8445a65
 
 
118e762
dbb62b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62475ba
dbb62b2
e11e26a
a91e356
dbb62b2
 
 
 
 
 
bfc9fea
dbb62b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d53d92
bfc9fea
e11e26a
 
dbb62b2
e11e26a
dbb62b2
38a7793
 
 
 
dbb62b2
38a7793
e11e26a
dbb62b2
38a7793
e11e26a
dbb62b2
 
 
 
 
 
e11e26a
dbb62b2
 
e11e26a
 
 
dbb62b2
 
118e762
dbb62b2
 
e11e26a
 
 
 
dbb62b2
 
38a7793
 
bfc9fea
 
e11e26a
 
dbb62b2
e11e26a
dbb62b2
605629e
bfc9fea
 
 
 
dbb62b2
bfc9fea
e11e26a
dbb62b2
bfc9fea
e11e26a
dbb62b2
 
 
 
 
 
e11e26a
dbb62b2
 
e11e26a
 
 
dbb62b2
 
e11e26a
dbb62b2
 
e11e26a
 
 
 
dbb62b2
 
bfc9fea
 
 
e11e26a
bfc9fea
e11e26a
 
 
 
dbb62b2
a91e356
dbb62b2
 
 
 
 
 
66294ea
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
# Enhanced Butterfly Identifier Streamlit App with Better Model Loading
import streamlit as st
from PIL import Image
import torch
import json
import os
import io
import numpy as np
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import warnings
# Install a catch-all "ignore" filter, silencing every warning category for
# the whole process (library deprecation notices included).
# NOTE(review): consider narrowing this to the specific warnings that were
# meant to be hidden, so genuinely useful warnings are not lost.
warnings.simplefilter('ignore')

# Page-level Streamlit settings: browser-tab title and icon, and a
# full-width ("wide") layout so the image and result columns sit side by side.
st.set_page_config(
    layout="wide",
    page_icon="🦋",
    page_title="Butterfly Identifier / Liblikamaja ID",
)

# Load class names
@st.cache_data
def load_class_names():
    """Read the classifier's labels from ``class_names.txt``, one per line.

    The position of each line is the class index used by the model head
    (``len`` of this list sizes the head, and predictions index into it), so
    interior lines are kept in place. Only trailing blank lines, such as
    extra newlines at end of file, are dropped because they would otherwise
    add phantom classes.

    Returns:
        The stripped class names, or ``[]`` if the file does not exist (an
        error is shown in the app).
    """
    try:
        # utf-8-sig reads plain UTF-8 unchanged but also strips a BOM, which
        # str.strip() would not remove from the first label.
        with open("class_names.txt", "r", encoding="utf-8-sig") as f:
            names = [line.strip() for line in f]
    except FileNotFoundError:
        st.error("class_names.txt file not found!")
        return []

    while names and not names[-1]:
        names.pop()
    return names

class_names = load_class_names()

# Load butterfly info
@st.cache_data
def load_butterfly_info():
    """Load optional per-species details from ``butterfly_info.json``.

    The result is looked up by predicted class name, and each entry's
    ``"description"`` is displayed. The file is optional:

    * missing file -> ``{}`` silently;
    * unreadable, undecodable or malformed file -> ``{}`` plus a warning;
    * valid JSON whose top level is not an object -> ``{}`` plus a warning.

    Returns:
        A dict mapping species name to its info, possibly empty.
    """
    try:
        # utf-8-sig: the file holds Estonian text; be explicit rather than
        # relying on the platform default, and tolerate a UTF-8 BOM.
        with open("butterfly_info.json", "r", encoding="utf-8-sig") as f:
            info = json.load(f)
    except FileNotFoundError:
        return {}
    except (OSError, ValueError) as e:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        st.warning(f"Could not read butterfly_info.json: {e}")
        return {}

    if not isinstance(info, dict):
        st.warning("butterfly_info.json must contain a JSON object; ignoring it.")
        return {}
    return info

butterfly_info = load_butterfly_info()

# Deterministic single-pass preprocessing, intended to match the training
# pipeline's evaluation transform (NOTE(review): confirm against the training
# script). Steps: resize to 224x224, normalise with ImageNet channel mean/std,
# then convert the HWC array into a CHW torch tensor.
inference_transform = A.Compose(
    [
        A.Resize(height=224, width=224),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ]
)

# Enhanced model loading function
# Checkpoint files to look for, in priority order (the first existing one wins).
_MODEL_FILES = (
    "butterfly_classifier.pth",
    "best_butterfly_model_v3.pth",
    "best_butterfly_model.pth",
)

# Output-channel count of ``conv_stem`` for timm EfficientNet/EfficientNetV2
# variants (e.g. B0 = 32, V2-S = 24). NOTE(review): taken from timm's model
# configs; re-check if new variants are added. This only decides which
# architectures are tried FIRST; _FALLBACK_ARCHITECTURES are always tried too.
_STEM_CHANNELS_TO_ARCHITECTURES = {
    24: ["tf_efficientnetv2_s", "tf_efficientnetv2_m"],
    32: ["efficientnet_b0", "efficientnet_b1", "efficientnet_b2", "tf_efficientnetv2_l"],
    40: ["efficientnet_b3"],
    48: ["efficientnet_b4", "efficientnet_b5"],
    56: ["efficientnet_b6"],
    64: ["efficientnet_b7"],
}

_FALLBACK_ARCHITECTURES = [
    "tf_efficientnetv2_s",
    "efficientnet_b0",
    "efficientnet_b1",
    "efficientnet_b2",
    "efficientnet_b3",
    "tf_efficientnetv2_m",
    "efficientnet_b4",
]


def _find_model_path():
    """Return the first file in ``_MODEL_FILES`` that exists, else ``None``."""
    return next((path for path in _MODEL_FILES if os.path.exists(path)), None)


def _extract_state_dict(checkpoint):
    """Return the parameter dict held by a loaded checkpoint.

    Accepts a whole pickled ``nn.Module``, a bare state dict, or a training
    checkpoint dict nesting it under ``model_state_dict`` / ``state_dict``.
    A ``module.`` prefix on every key (as left by ``nn.DataParallel``) is
    stripped so the keys match a plain model.
    """
    if isinstance(checkpoint, torch.nn.Module):
        return checkpoint.state_dict()

    state_dict = checkpoint
    for key in ("model_state_dict", "state_dict"):
        nested = checkpoint.get(key) if isinstance(checkpoint, dict) else None
        if isinstance(nested, dict):
            state_dict = nested
            break

    prefix = "module."
    if state_dict and all(k.startswith(prefix) for k in state_dict):
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
    return state_dict


def _candidate_architectures(state_dict):
    """Return timm architecture names to try, most likely first, de-duplicated."""
    likely = []
    for key, tensor in state_dict.items():
        if key.endswith("conv_stem.weight"):
            likely = _STEM_CHANNELS_TO_ARCHITECTURES.get(tensor.shape[0], [])
            break
    # dict.fromkeys removes duplicates while keeping first-seen order.
    return list(dict.fromkeys(likely + _FALLBACK_ARCHITECTURES))


def _create_model(arch, num_classes):
    """Build an untrained timm model with dropout disabled for inference."""
    return timm.create_model(
        arch,
        pretrained=False,
        num_classes=num_classes,
        drop_rate=0.0,
        drop_path_rate=0.0,
    )


@st.cache_resource
def load_model():
    """Locate a checkpoint, work out its timm architecture and load it.

    Every candidate architecture is tried with a strict ``load_state_dict``
    first. A lenient (``strict=False``) load is accepted only when no
    candidate matches exactly. In that case the candidate leaving the fewest
    model tensors unfilled is used, with a warning. A candidate that loads
    nothing at all is never accepted, so a randomly initialised network is
    never reported as loaded.

    Returns:
        The model in eval mode, or ``None`` on failure (reported via
        ``st.error``).
    """
    model_path = _find_model_path()
    if model_path is None:
        st.error("No model file found!")
        return None

    num_classes = len(class_names)
    if num_classes == 0:
        st.error("❌ No class names available; cannot build the classifier head.")
        return None

    st.info(f"Loading model from: {model_path}")

    try:
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints (torch>=1.13 supports weights_only=True for plain
        # state dicts).
        checkpoint = torch.load(model_path, map_location='cpu')
        state_dict = _extract_state_dict(checkpoint)

        model = None
        successful_arch = None
        fallback = None  # (missing_tensor_count, arch, model) of best lenient load

        for arch in _candidate_architectures(state_dict):
            st.info(f"Trying architecture: {arch}")
            try:
                candidate = _create_model(arch, num_classes)
            except Exception as e:
                st.warning(f"Failed to create model {arch}: {str(e)}")
                continue

            try:
                candidate.load_state_dict(state_dict, strict=True)
            except Exception:
                pass
            else:
                st.success(f"✅ Successfully loaded model with architecture: {arch}")
                model, successful_arch = candidate, arch
                break

            # Not an exact match: record how well a lenient load fits, but
            # keep searching for an exact match before settling for it.
            try:
                result = candidate.load_state_dict(state_dict, strict=False)
            except Exception as e:
                # Shape mismatches raise even with strict=False.
                st.warning(f"Failed to load {arch}: {str(e)}")
                continue
            missing = len(result.missing_keys)
            if missing < len(candidate.state_dict()) and (
                fallback is None or missing < fallback[0]
            ):
                fallback = (missing, arch, candidate)

        if model is None and fallback is not None:
            missing, successful_arch, model = fallback
            st.warning(
                f"⚠️ Loaded {successful_arch} with some mismatched weights "
                f"({missing} model tensors not found in checkpoint)"
            )

        if model is None:
            st.error("❌ Failed to load model with any architecture!")
            return None

        model.eval()

        total_params = sum(p.numel() for p in model.parameters())
        st.success(f"✅ Model loaded successfully!")
        st.info(f"📊 Model: {successful_arch}")
        st.info(f"🔢 Parameters: {total_params:,}")
        st.info(f"🎯 Classes: {num_classes}")

        return model

    except Exception as e:
        st.error(f"❌ Error loading model: {str(e)}")
        return None

# Load model
model = load_model()

def predict_butterfly(image, threshold=0.5):
    """Classify a butterfly image with one forward pass of the global model.

    Args:
        image: A PIL image or a NumPy array (wrapped with ``Image.fromarray``);
            non-RGB images are converted to RGB.
        threshold: Minimum top-class softmax probability needed to name a
            species.

    Returns:
        ``(class_name, confidence)`` when the top probability reaches
        ``threshold``; ``(None, confidence)`` when it does not; ``(None, None)``
        when ``image`` is ``None`` or any error occurs (errors, including an
        unloaded model, are shown via ``st.error``).
    """
    try:
        if model is None:
            raise ValueError("Model is not loaded.")
        if image is None:
            return None, None

        pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        if pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')

        # Albumentations works on arrays; add a batch dimension of 1.
        batch = inference_transform(image=np.array(pil_image))['image'].unsqueeze(0)

        with torch.no_grad():
            logits = model(batch)[0]
            probs = torch.nn.functional.softmax(logits, dim=0)
            top_prob, top_idx = torch.max(probs, 0)
            score = top_prob.item()
            if score < threshold:
                return None, score
            return class_names[top_idx.item()], score

    except Exception as e:
        st.error(f"Prediction error: {str(e)}")
        return None, None

def predict_with_tta(image, threshold=0.5, num_tta=5):
    """Predict with Test Time Augmentation for better accuracy.

    The image is classified under ``num_tta`` augmented views in a single
    batched forward pass, and the per-view softmax probabilities are
    averaged. The first five views, in order, are:

    1. plain 224x224 resize;
    2. resize to 256 then 224 centre crop;
    3. horizontal flip;
    4. small random rotation (up to 10 degrees);
    5. mild random brightness/contrast jitter.

    Rounds beyond five cycle through the two random views (4 and 5). Each
    call draws a fresh sample, so extra rounds add information instead of
    being silently ignored. Because of the random views, results can differ
    slightly between calls.

    Args:
        image: A PIL image or a NumPy array; non-RGB images are converted.
        threshold: Minimum averaged top-class probability needed to name a
            species.
        num_tta: Number of views to average; values below 1 are treated as 1.

    Returns:
        ``(class_name, confidence)``, ``(None, confidence)`` when below
        ``threshold``, or ``(None, None)`` when ``image`` is ``None`` or on
        error (shown via ``st.error``).
    """
    try:
        if model is None:
            raise ValueError("Model is not loaded.")
        if image is None:
            return None, None

        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_np = np.array(image)  # albumentations operates on arrays

        def view(*augmentations):
            # Every view ends with the same ImageNet normalisation and tensor
            # conversion as the single-pass inference transform.
            return A.Compose([
                *augmentations,
                A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2(),
            ])

        deterministic_views = [
            view(A.Resize(224, 224)),
            view(A.Resize(256, 256), A.CenterCrop(224, 224)),
            view(A.Resize(224, 224), A.HorizontalFlip(p=1.0)),
        ]
        random_views = [
            view(A.Resize(240, 240), A.Rotate(limit=10, p=1.0), A.CenterCrop(224, 224)),
            # saturation/hue pinned to 0: albumentations' ColorJitter has
            # non-zero defaults for both, and hue shifts recolour the wings,
            # which is exactly what distinguishes many species.
            view(
                A.Resize(224, 224),
                A.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.0, hue=0.0, p=1.0),
            ),
        ]

        rounds = max(1, int(num_tta))
        views = (deterministic_views + random_views)[:rounds]
        views += [random_views[i % len(random_views)] for i in range(rounds - len(views))]

        # One (rounds, C, H, W) batch -> one forward pass instead of one per view.
        batch = torch.stack([v(image=image_np)['image'] for v in views])
        with torch.no_grad():
            probabilities = torch.nn.functional.softmax(model(batch), dim=1)

        avg_probabilities = probabilities.mean(dim=0)
        confidence, pred = torch.max(avg_probabilities, 0)
        score = confidence.item()

        if score < threshold:
            return None, score
        return class_names[pred.item()], score

    except Exception as e:
        st.error(f"TTA Prediction error: {str(e)}")
        return None, None

# --- UI: page heading and model readiness -----------------------------------
st.title("🦋 Liblikamaja ID / Butterfly Identifier")
st.write("Tuvasta liblikaid oma kaamera abil või laadi üles pilt! / Identify butterflies using your camera or by uploading an image!")

# Nothing below works without a model, so halt the script run right here
# when loading failed; otherwise show a ready indicator.
if model is None:
    st.error("❌ Model not loaded. Please check your model file.")
    st.stop()
else:
    st.success("✅ Model loaded and ready!")

# --- UI: optional tuning controls --------------------------------------------
# confidence_threshold and use_tta always exist after this block; tta_rounds
# exists only when TTA is enabled, and is only read under that same condition.
with st.expander("🔧 Advanced Options / Täpsemad seaded"):
    confidence_threshold = st.slider(
        "Confidence Threshold / Kindluse lävi",
        help="Higher values = more conservative predictions",
        min_value=0.1,
        max_value=1.0,
        step=0.05,
        value=0.5,
    )

    use_tta = st.checkbox(
        "Use Test Time Augmentation (TTA) / Kasuta TTA",
        help="Slower but potentially more accurate predictions",
        value=False,
    )

    if use_tta:
        tta_rounds = st.slider(
            "TTA Rounds / TTA ringid",
            help="More rounds = slower but potentially more accurate",
            min_value=3,
            max_value=8,
            value=5,
        )

def _show_prediction(image, caption):
    """Render *image* beside its identification result (or tips when unsure).

    Shared by the camera and upload tabs. Reads the Advanced Options settings
    (``confidence_threshold``, ``use_tta`` and, only when TTA is on,
    ``tta_rounds``) from module scope.
    """
    col1, col2 = st.columns(2)

    with col1:
        st.image(image, caption=caption, use_column_width=True)

    with col2:
        with st.spinner("Pildi analüüsimine... / Analyzing image..."):
            if use_tta:
                predicted_class, confidence = predict_with_tta(image, confidence_threshold, tta_rounds)
            else:
                predicted_class, confidence = predict_butterfly(image, confidence_threshold)

        if predicted_class and confidence >= confidence_threshold:
            st.success(f"**Liblikas / Butterfly: {predicted_class}**")
            st.info(f"Confidence: {confidence:.2%}")

            # Entries without a "description" (or a non-dict info file) fall
            # back to the "no information" message instead of raising.
            species_info = butterfly_info.get(predicted_class) if isinstance(butterfly_info, dict) else None
            description = species_info.get("description") if isinstance(species_info, dict) else None
            if description:
                st.markdown("**Liigi kirjeldus / About this species:**")
                st.write(description)
            else:
                st.info("No additional information available for this species.")
        else:
            # `is not None` so that a genuine 0% confidence is still reported.
            confidence_text = f" (Confidence: {confidence:.2%})" if confidence is not None else ""
            st.warning(f"⚠️ Ma ei tea, mis liblikas see on / I don't know what butterfly this is{confidence_text}")
            st.markdown("**Näpunäited paremate tulemuste saavutamiseks / Tips for better results:**")
            for tip in (
                "- Kasutage paremat valgustust / Use better lighting",
                "- Veenduge, et liblikas oleks selgelt nähtav / Ensure the butterfly is clearly visible",
                "- Vältige uduseid või tumedaid pilte / Avoid blurry or dark images",
                "- Proovige madalamat kindluse läviväärtust / Try a lower confidence threshold",
            ):
                st.markdown(tip)


tab1, tab2 = st.tabs(["📷 Live Camera / Kaamera", "📁 Upload Image / Laadi üles"])

with tab1:
    st.header("Kaamera jäädvustamine / Camera Capture")
    st.write("Tee pilt liblikast tuvastamiseks / Take a photo of a butterfly for identification.")

    camera_photo = st.camera_input("Pildista liblikat / Capture a butterfly")

    if camera_photo is not None:
        try:
            image = Image.open(camera_photo).convert("RGB")
            _show_prediction(image, "Jäädvustatud pilt / Captured Image")
        except Exception as e:
            st.error(f"Error processing image: {str(e)}")

with tab2:
    st.header("Laadi üles pilt / Upload Image")
    st.write("Laadige üles liblika selge foto tuvastamiseks / Upload a clear photo of a butterfly for identification.")

    uploaded_file = st.file_uploader("Vali pilt... / Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        try:
            image_bytes = uploaded_file.read()
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            _show_prediction(image, "Üleslaetud pilt / Uploaded Image")
        except Exception as e:
            st.error(f"Error processing image: {str(e)}")

# --- Footer: bilingual (Estonian / English) usage instructions --------------
st.markdown("---")
st.markdown("### Kuidas kasutada / How to use:")
for usage_step in (
    "1. **Kaamera jäädvustamine / Camera Capture**: Tehke foto oma seadme kaameraga / Take a photo using your device camera",
    "2. **Laadi pilt üles / Upload Image**: Vali oma seadmest liblika foto / Choose a butterfly photo from your device",
    "3. **Parimad tulemused / Best Results**: Kasutage selgeid ja hästi valgustatud fotosid, kus liblikas on selgelt nähtav / Use clear, well-lit photos with the butterfly clearly visible",
    "4. **Täpsemad seaded / Advanced Options**: Kohandage kindluse lävi ja kasutage TTA paremate tulemuste saamiseks / Adjust confidence threshold and use TTA for better results",
):
    st.markdown(usage_step)

# Debug info: optional panel showing the loaded labels, model state and
# species-info coverage.
if st.checkbox("Show Debug Info"):
    st.write("**Class Names:**", class_names)
    st.write("**Number of Classes:**", len(class_names))
    # Compare against None explicitly: torch container modules (e.g.
    # nn.Sequential) define __len__, so truthiness is not a reliable
    # "is loaded" test.
    st.write("**Model Status:**", "Loaded" if model is not None else "Not Loaded")
    if butterfly_info:
        st.write("**Species Info Available:**", len(butterfly_info))