File size: 7,686 Bytes
02001f5
 
 
 
 
605629e
bfc9fea
 
 
 
 
bf01465
02001f5
af316a5
 
bfc9fea
af316a5
 
 
 
02001f5
 
 
 
 
 
 
 
 
 
 
 
 
a233bbd
8bcdd21
a233bbd
8bcdd21
 
 
aa86cbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02001f5
aa86cbd
02001f5
 
bf01465
02001f5
 
6052c4b
 
 
02001f5
 
 
 
 
bfc9fea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02001f5
bfc9fea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf01465
bfc9fea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38a7793
 
bfc9fea
38a7793
 
bfc9fea
38a7793
 
 
 
 
 
 
 
 
 
 
 
 
bf01465
38a7793
 
 
34fd92b
 
 
bf01465
 
 
 
 
 
 
 
 
38a7793
 
bfc9fea
 
 
 
 
605629e
 
 
bfc9fea
605629e
 
 
bfc9fea
 
 
 
af316a5
bfc9fea
 
 
 
 
 
 
 
bf01465
bfc9fea
 
 
 
 
 
bf01465
 
 
 
 
 
 
 
bfc9fea
 
 
 
 
 
 
bf01465
bfc9fea
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import streamlit as st
from PIL import Image
import torch
from torchvision import models, transforms
import json
import os
import io
import numpy as np
from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
import av
import cv2
import timm  

# Configure the browser tab (title/icon) and use the wide page layout.
st.set_page_config(
    page_title="Butterfly Identifier/Liblikamaja ID",
    page_icon="🦋",
    layout="wide"
)

# Load class names: one label per line. Line order is the model's output index
# (predict_butterfly looks labels up by the argmax position).
with open("class_names.txt", "r", encoding="utf-8") as f:
    class_names = [line.strip() for line in f]

# Load optional per-species information. It is best-effort: a missing,
# unreadable, or malformed file just means no extra info is shown. Only I/O and
# decode/parse errors are caught, so KeyboardInterrupt/SystemExit still
# propagate.
try:
    with open("butterfly_info.json", "r", encoding="utf-8") as f:
        butterfly_info = json.load(f)
except (OSError, ValueError):  # ValueError covers JSONDecodeError/UnicodeDecodeError
    butterfly_info = {}

@st.cache_resource
def load_model():
    """Load the butterfly classifier weights into a timm EfficientNet-B0.

    Reads ``butterfly_classifier.pth`` from the working directory. The number
    of output classes is taken from the saved classifier weights, so the
    architecture matches the checkpoint. The result is cached with
    ``st.cache_resource``, so the weights are loaded once per process.

    Returns:
        The model in eval mode. Returns ``None`` (after showing an error
        message) when the weights file does not exist.
    """
    MODEL_PATH = "butterfly_classifier.pth"

    if not os.path.exists(MODEL_PATH):
        st.error("Model file not found. Please upload butterfly_classifier.pth to your space.")
        return None

    checkpoint = torch.load(MODEL_PATH, map_location="cpu")

    # Training scripts often save {"state_dict": ...} or
    # {"model_state_dict": ...} instead of the bare state dict. Unwrap those so
    # either layout loads. A bare state dict maps names to tensors, never to a
    # nested dict, so it passes through unchanged.
    if isinstance(checkpoint, dict):
        for wrapper_key in ("state_dict", "model_state_dict"):
            if isinstance(checkpoint.get(wrapper_key), dict):
                checkpoint = checkpoint[wrapper_key]
                break

    # Get the number of classes from the saved classifier head.
    if 'classifier.weight' in checkpoint:
        num_classes_in_model = checkpoint['classifier.weight'].shape[0]
    elif 'fc.weight' in checkpoint:  # Alternative naming
        num_classes_in_model = checkpoint['fc.weight'].shape[0]
    else:
        # Fallback: assume it matches class_names
        num_classes_in_model = len(class_names)

    # predict_butterfly indexes class_names by the model's argmax. A size
    # mismatch means labels are wrong or some outputs have no name, so surface
    # it instead of failing later.
    if num_classes_in_model != len(class_names):
        st.warning(
            f"Model has {num_classes_in_model} output classes but class_names.txt "
            f"lists {len(class_names)}; predictions may be mislabeled."
        )

    # Create model with the correct number of classes from the saved model
    model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=num_classes_in_model)
    model.load_state_dict(checkpoint)
    model.eval()

    return model

# Build the classifier, or fetch it from Streamlit's resource cache. Without
# weights there is nothing to show, so stop this script run immediately.
model = load_model()
if model is None:
    st.stop()

# Preprocessing applied to every image before inference:
# resize to 224x224, then convert to a float tensor scaled to [0, 1].
# NOTE(review): no mean/std normalisation is applied here. Presumably this
# matches the training pipeline; confirm before changing it.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

def predict_butterfly(image):
    """Classify a butterfly image with the loaded model.

    Args:
        image: A PIL image, an ``np.ndarray`` (converted with
            ``Image.fromarray``; callers pass RGB frames), or ``None``.

    Returns:
        ``(predicted_class, confidence)``. ``confidence`` is the top softmax
        probability in [0, 1]. Returns ``(None, None)`` when *image* is
        ``None``. If the model predicts an index beyond ``class_names``, the
        label is ``"class_<index>"`` rather than raising ``IndexError``.
    """
    if image is None:
        return None, None

    # Convert to PIL Image if needed
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # The network is built with 3 input channels. RGBA, greyscale, or palette
    # images would give ToTensor the wrong channel count, so normalise the mode.
    if image.mode != "RGB":
        image = image.convert("RGB")

    input_tensor = transform(image).unsqueeze(0)  # add a batch dimension of 1

    with torch.no_grad():
        output = model(input_tensor)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        confidence, pred = torch.max(probabilities, 0)

    index = pred.item()
    # Guard against a model with more outputs than names in class_names.txt.
    if index < len(class_names):
        predicted_class = class_names[index]
    else:
        predicted_class = f"class_{index}"

    return predicted_class, confidence.item()

# Video frame callback for live camera
class VideoProcessor:
    """Overlay the current butterfly prediction on live video frames.

    The classifier runs only on every ``frame_interval``-th frame to keep the
    stream responsive. The most recent confident label is drawn on every frame
    until a later low-confidence prediction clears it.
    """

    def __init__(self, frame_interval=30, confidence_threshold=0.8):
        """Create a frame processor.

        Args:
            frame_interval: Run the classifier once every this many frames.
                Must be >= 1.
            confidence_threshold: Minimum confidence (inclusive) for a label
                to be shown.

        Raises:
            ValueError: If ``frame_interval`` is less than 1.
        """
        if frame_interval < 1:
            raise ValueError("frame_interval must be >= 1")
        self.prediction_text = ""  # label currently drawn on frames ("" = none)
        self.frame_count = 0  # number of frames received so far
        self.frame_interval = frame_interval
        self.confidence_threshold = confidence_threshold

    def recv(self, frame):
        """Process one video frame and return it with the label drawn on it.

        Args:
            frame: Incoming frame exposing ``to_ndarray(format=...)``.
                Presumably an ``av.VideoFrame`` from streamlit-webrtc; confirm
                against the caller.

        Returns:
            An ``av.VideoFrame`` in ``bgr24`` format.
        """
        img = frame.to_ndarray(format="bgr24")

        self.frame_count += 1
        if self.frame_count % self.frame_interval == 0:
            # Frames arrive as BGR; predict_butterfly is fed RGB, as in the UI.
            rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            predicted_class, confidence = predict_butterfly(rgb_img)

            # Same inclusive threshold semantics as the upload/camera tabs.
            if predicted_class and confidence >= self.confidence_threshold:
                self.prediction_text = f"{predicted_class} ({confidence:.2f})"
            else:
                # Clear the old label so it does not linger after the
                # butterfly leaves the frame or the prediction becomes uncertain.
                self.prediction_text = ""

        # Draw prediction on frame
        if self.prediction_text:
            cv2.putText(img, self.prediction_text, (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

        return av.VideoFrame.from_ndarray(img, format="bgr24")

# ---- Streamlit UI: page heading ---------------------------------------------
st.title("🦋 Butterfly Identifier / Liblikamaja ID")
st.write("Identify butterflies using your camera or by uploading an image!")

# One tab per input method: camera capture first, file upload second.
tab1, tab2 = st.tabs(
    [
        "📷 Live Camera",
        "📁 Upload Image",
    ]
)

# Minimum top-class probability (inclusive) before a prediction is displayed.
CONFIDENCE_THRESHOLD = 0.80


def _render_prediction(image, caption, show_species_info):
    """Show *image* beside its predicted species, or tips when confidence is low.

    Args:
        image: RGB PIL image to display and classify.
        caption: Caption rendered under the image.
        show_species_info: When True and ``butterfly_info`` has a description
            for the predicted class, render that description below the result.
    """
    col1, col2 = st.columns(2)

    with col1:
        st.image(image, caption=caption, use_column_width=True)

    with col2:
        predicted_class, confidence = predict_butterfly(image)

        if predicted_class and confidence >= CONFIDENCE_THRESHOLD:
            st.success(f"**Prediction: {predicted_class}**")
            st.info(f"Confidence: {confidence:.2%}")

            if show_species_info and predicted_class in butterfly_info:
                entry = butterfly_info[predicted_class]
                # Tolerate entries without a "description". Indexing directly
                # would raise KeyError after the prediction was already shown.
                description = entry.get("description") if isinstance(entry, dict) else None
                if description:
                    st.write("**Species Information:**")
                    st.write(description)
        else:
            st.warning("⚠️ **Image not clear - Unable to identify butterfly**")
            st.info(f"Confidence too low: {confidence:.1%}")
            st.markdown("**Tips for better results:**")
            st.markdown("- Use better lighting")
            st.markdown("- Get closer to the butterfly")
            st.markdown("- Ensure the butterfly is clearly visible")
            st.markdown("- Avoid blurry or dark images")


with tab1:
    st.header("Camera Capture")
    st.write("Take a photo of a butterfly for identification!")

    # Use Streamlit's built-in camera input
    camera_photo = st.camera_input("Take a picture of a butterfly")

    if camera_photo is not None:
        try:
            image = Image.open(camera_photo).convert("RGB")
            # Species descriptions are disabled for camera captures.
            _render_prediction(image, "Captured Image", show_species_info=False)
        except Exception as e:
            st.error(f"Error processing image: {str(e)}")

with tab2:
    st.header("Upload Image")
    st.write("Upload a clear photo of a butterfly for identification")

    uploaded_file = st.file_uploader(
        "Choose an image...",
        type=["jpg", "jpeg", "png"],
        help="Upload a clear photo of a butterfly"
    )

    if uploaded_file is not None:
        try:
            # getvalue() returns the whole upload regardless of the current
            # read position, unlike read().
            image = Image.open(io.BytesIO(uploaded_file.getvalue())).convert("RGB")
            _render_prediction(image, "Uploaded Image", show_species_info=True)
        except Exception as e:
            st.error(f"Error processing image: {str(e)}")

# Footer: a short usage guide rendered below both tabs.
st.markdown("---")
st.markdown("### How to use:")
for _usage_step in (
    "1. **Camera Capture**: Take a photo using your device camera",
    "2. **Upload Image**: Choose a butterfly photo from your device",
    "3. **Best Results**: Use clear, well-lit photos with the butterfly clearly visible",
):
    st.markdown(_usage_step)