File size: 8,197 Bytes
e11e26a
02001f5
 
 
 
605629e
bfc9fea
 
8c61898
e11e26a
 
1809961
 
02001f5
af316a5
 
e11e26a
af316a5
 
 
 
02001f5
1809961
 
 
 
 
 
 
 
 
 
02001f5
 
1809961
 
 
 
 
8c61898
1809961
 
 
 
e11e26a
 
 
 
 
 
02001f5
e11e26a
02001f5
 
a233bbd
 
e11e26a
1809961
9625ec8
1809961
e11e26a
1809961
06966cd
9d53d92
 
 
 
 
 
6be4ed9
06966cd
 
 
 
 
 
 
 
 
 
9d53d92
 
 
 
 
 
 
06966cd
 
e11e26a
1809961
 
 
 
 
02001f5
7fb82fb
02001f5
118e762
8445a65
7fb82fb
 
8445a65
 
 
 
 
 
e11e26a
 
8445a65
 
 
 
8c61898
 
8445a65
 
 
 
 
118e762
62475ba
e11e26a
 
a91e356
e11e26a
bfc9fea
9d53d92
bfc9fea
e11e26a
 
 
38a7793
 
 
 
 
e11e26a
38a7793
e11e26a
a91e356
8c61898
e11e26a
 
 
 
118e762
e11e26a
 
 
 
 
38a7793
 
bfc9fea
 
e11e26a
 
 
605629e
bfc9fea
 
 
 
 
e11e26a
bfc9fea
e11e26a
a91e356
e11e26a
 
 
 
 
 
 
 
 
 
 
bfc9fea
 
 
e11e26a
bfc9fea
e11e26a
 
 
 
a91e356
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
# Full corrected bilingual Streamlit app for Butterfly Identifier
import streamlit as st
from PIL import Image
import torch
import json
import os
import io
import numpy as np
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import warnings
# NOTE(review): this silences every Python warning for the whole process,
# including deprecation notices from the ML libraries — consider narrowing it.
warnings.simplefilter('ignore')

# Page-level Streamlit settings: browser-tab title, favicon and full-width layout.
st.set_page_config(
    layout="wide",
    page_icon="🦋",
    page_title="Butterfly Identifier / Liblikamaja ID",
)

# Load class names
@st.cache_data
def load_class_names():
    """Read the ordered list of species labels from ``class_names.txt``.

    Entry *i* is the label for model output index *i* (predictions index this
    list with the argmax), and its length sets the classifier's ``num_classes``,
    so the order and count must match the training run.

    Returns:
        list[str]: One stripped label per line, with trailing blank lines
        dropped, or ``[]`` (after showing an error in the app) when the file
        does not exist.
    """
    try:
        # Explicit UTF-8 so the labels decode identically on every platform.
        with open("class_names.txt", "r", encoding="utf-8") as f:
            names = [line.strip() for line in f]
    except FileNotFoundError:
        st.error("class_names.txt file not found!")
        return []
    # A trailing empty line would otherwise become an extra "" class, making
    # num_classes disagree with the checkpoint's classifier shape.
    while names and not names[-1]:
        names.pop()
    return names


class_names = load_class_names()

# Load butterfly info
@st.cache_data
def load_butterfly_info():
    """Load optional per-species metadata from ``butterfly_info.json``.

    The UI looks entries up by predicted class name and reads their
    ``"description"`` field.

    Returns:
        dict: The parsed JSON object, or ``{}`` when the file is absent,
        unreadable, malformed, or not a JSON object. A missing file is treated
        as "no extra info" silently; the other failures are shown as warnings
        instead of being hidden.
    """
    try:
        # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open("butterfly_info.json", "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return {}
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        st.warning(f"Could not read butterfly_info.json: {e}")
        return {}
    if not isinstance(data, dict):
        # Lookups like butterfly_info[name]["description"] need a mapping.
        st.warning("butterfly_info.json must contain a JSON object; ignoring it.")
        return {}
    return data


butterfly_info = load_butterfly_info()

# Preprocessing applied to every image before inference; intended to mirror the
# training pipeline: resize to 224x224, normalise each channel with the
# (ImageNet) mean/std below, then convert the numpy image to a torch tensor.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

inference_transform = A.Compose(
    [
        A.Resize(height=224, width=224),
        A.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
        ToTensorV2(),
    ]
)

# Load the model
# Final feature width (timm ``num_features``) of each EfficientNet variant.
# efficientnet_b1 only scales depth relative to b0, so both share a 1280-wide
# head; the checkpoint's keys decide between them.
_EFFICIENTNET_BY_WIDTH = {
    1280: ('efficientnet_b0', 'efficientnet_b1'),
    1408: ('efficientnet_b2',),
    1536: ('efficientnet_b3',),
    1792: ('efficientnet_b4',),
    2048: ('efficientnet_b5',),
    2304: ('efficientnet_b6',),
    2560: ('efficientnet_b7',),
}
_DEFAULT_ARCHITECTURE = 'efficientnet_b3'


def _strip_module_prefix(state_dict):
    """Remove the ``module.`` key prefix added when saving from ``nn.DataParallel``."""
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _detect_feature_width(state_dict):
    """Return the width of the network's final feature map, or None if unknown."""
    # Look up exact top-level keys only. timm's inner blocks have their own,
    # much narrower ``blocks.*.bn2.weight`` entries that come first in the
    # state dict, so a suffix match would report the wrong width.
    for key in ("bn2.weight", "conv_head.weight"):
        if key in state_dict:
            return state_dict[key].shape[0]
    return None


def _create_efficientnet(model_name, num_classes):
    """Instantiate an untrained timm model with this app's head and dropout settings."""
    return timm.create_model(
        model_name,
        pretrained=False,
        num_classes=num_classes,
        drop_rate=0.4,
        drop_path_rate=0.3,
    )


@st.cache_resource
def load_model():
    """Build the classifier and load weights from ``butterfly_classifier.pth``.

    The architecture is inferred from the checkpoint's final feature width.
    Each candidate variant is tried with a strict load; if none fits exactly,
    the first candidate is loaded leniently and the number of missing /
    unexpected keys is reported, because those layers keep random weights.

    Returns:
        The model in eval mode on CPU, or ``None`` (after showing an error in
        the app) if the file or class names are missing or loading fails.
    """
    MODEL_PATH = "butterfly_classifier.pth"
    if not os.path.exists(MODEL_PATH):
        st.error("Model file not found!")
        return None
    num_classes = len(class_names)
    if num_classes == 0:
        # timm would build the model without a classification layer, and every
        # prediction would later fail with an index error.
        st.error("Cannot build the model: no class names were loaded.")
        return None
    try:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(MODEL_PATH, map_location='cpu')
        model_state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
        model_state_dict = _strip_module_prefix(model_state_dict)

        width = _detect_feature_width(model_state_dict)
        candidates = _EFFICIENTNET_BY_WIDTH.get(width)
        if candidates is None:
            st.warning(
                f"Could not detect the model architecture from the checkpoint "
                f"(feature width: {width}). Defaulting to {_DEFAULT_ARCHITECTURE}"
            )
            candidates = (_DEFAULT_ARCHITECTURE,)

        for model_name in candidates:
            model = _create_efficientnet(model_name, num_classes)
            try:
                model.load_state_dict(model_state_dict, strict=True)
            except RuntimeError:
                continue  # keys or shapes don't fit this variant; try the next one
            st.info(f"Detected model architecture: {model_name}")
            model.eval()
            return model

        # No exact fit: load whatever matches into the first candidate, but say so.
        # (Shape mismatches still raise here and are reported below.)
        model_name = candidates[0]
        model = _create_efficientnet(model_name, num_classes)
        incompatible = model.load_state_dict(model_state_dict, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            st.warning(
                f"Checkpoint does not fully match {model_name}: "
                f"{len(incompatible.missing_keys)} missing and "
                f"{len(incompatible.unexpected_keys)} unexpected keys. "
                "Predictions may be unreliable."
            )
        st.info(f"Using model architecture: {model_name}")
        model.eval()
        return model
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None


model = load_model()

def predict_butterfly(image, threshold=0.5):
    """Classify one image with the globally loaded model.

    Args:
        image: A PIL image or a numpy array (wrapped with ``Image.fromarray``);
            non-RGB images are converted to RGB.
        threshold: Minimum top-class softmax probability for a label to be returned.

    Returns:
        ``(label, confidence)`` when the top probability reaches *threshold*,
        ``(None, confidence)`` when it falls below it, and ``(None, None)`` when
        *image* is None or any error occurs (the error is shown in the app).
    """
    try:
        if model is None:
            raise ValueError("Model is not loaded.")
        if image is None:
            return None, None

        pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        if pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')

        # Add a leading batch dimension of 1 to the transformed tensor.
        batch = inference_transform(image=np.array(pil_image))['image'].unsqueeze(0)

        with torch.no_grad():
            logits = model(batch)
            probs = torch.nn.functional.softmax(logits[0], dim=0)
            top_prob, top_idx = torch.max(probs, 0)
            score = top_prob.item()
            if score < threshold:
                return None, score
            return class_names[top_idx.item()], score
    except Exception as e:
        st.error(f"Prediction error: {str(e)}")
        return None, None

# UI Code
# Minimum softmax probability for a prediction to be shown as an identification.
CONFIDENCE_THRESHOLD = 0.50


def _show_species_description(predicted_class):
    """Show the species description from butterfly_info, if an entry provides one."""
    info = butterfly_info.get(predicted_class) if isinstance(butterfly_info, dict) else None
    description = info.get("description") if isinstance(info, dict) else None
    # Skip the heading too when there is nothing to show under it.
    if description:
        st.markdown("**Liigi kirjeldus / About this species:**")
        st.write(description)


def _show_unrecognized_tips():
    """Tell the user the butterfly was not recognized and how to take a better photo."""
    st.warning("⚠️ Ma ei tea, mis liblikas see on / I don't know what butterfly this is")
    st.markdown("**Näpunäited paremate tulemuste saavutamiseks / Tips for better results:**")
    st.markdown("- Kasutage paremat valgustust / Use better lighting")
    st.markdown("- Veenduge, et liblikas oleks selgelt nähtav / Ensure the butterfly is clearly visible")
    st.markdown("- Vältige uduseid või tumedaid pilte / Avoid blurry or dark images")


def _render_identification(image, caption):
    """Show *image* in the left column and its identification in the right one."""
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption=caption, use_column_width=True)
    with col2:
        with st.spinner("Pildi analüüsimine... / Analyzing image..."):
            predicted_class, confidence = predict_butterfly(image, threshold=CONFIDENCE_THRESHOLD)
        # predict_butterfly returns no label below the threshold or on error.
        if predicted_class and confidence >= CONFIDENCE_THRESHOLD:
            st.success(f"**Liblikas / Butterfly: {predicted_class}**")
            _show_species_description(predicted_class)
        else:
            _show_unrecognized_tips()


st.title("🦋  Liblikamaja ID / Butterfly Identifier")
st.write("Tuvasta liblikaid oma kaamera abil või laadi üles pilt! / Identify butterflies using your camera or by uploading an image!")

tab1, tab2 = st.tabs(["📷 Live Camera / Kaamera", "📁 Upload Image / Laadi üles"])


with tab1:
    st.header("Kaamera jäädvustamine / Camera Capture")
    st.write("Tee pilt liblikast tuvastamiseks / Take a photo of a butterfly for identification.")
    camera_photo = st.camera_input("Pildista liblikat / Capture a butterfly")
    if camera_photo is not None:
        try:
            image = Image.open(camera_photo).convert("RGB")
            _render_identification(image, "Jäädvustatud pilt / Captured Image")
        except Exception as e:
            st.error(f"Error processing image: {str(e)}")

with tab2:
    st.header("Laadi üles pilt / Upload Image")
    st.write("Laadige üles liblika selge foto tuvastamiseks / Upload a clear photo of a butterfly for identification.")
    uploaded_file = st.file_uploader("Vali pilt... / Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        try:
            # getvalue() returns the whole upload regardless of the stream position.
            image = Image.open(io.BytesIO(uploaded_file.getvalue())).convert("RGB")
            _render_identification(image, "Üleslaetud pilt / Uploaded Image")
        except Exception as e:
            st.error(f"Error processing image: {str(e)}")

# Footer: usage instructions rendered after both tabs.
_USAGE_STEPS = (
    "1. **Kaamera jäädvustamine / Camera Capture**: Tehke foto oma seadme kaameraga / Take a photo using your device camera",
    "2. **Laadi pilt üles / Upload Image**: Vali oma seadmest liblika foto / Choose a butterfly photo from your device",
    "3. **Parimad tulemused / Best Results**: Kasutage selgeid ja hästi valgustatud fotosid, kus liblikas on selgelt nähtav / Use clear, well-lit photos with the butterfly clearly visible",
)

st.markdown("---")
st.markdown("### Kuidas kasutada / How to use:")
for _step in _USAGE_STEPS:
    st.markdown(_step)