File size: 2,680 Bytes
02001f5
 
 
 
 
605629e
 
02001f5
af316a5
 
 
 
 
 
 
02001f5
 
 
 
 
 
 
 
 
 
 
 
 
a233bbd
8bcdd21
a233bbd
8bcdd21
 
 
02001f5
 
a233bbd
02001f5
 
 
 
 
6052c4b
 
 
02001f5
 
 
 
 
 
af316a5
02001f5
605629e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af316a5
02001f5
af316a5
 
02001f5
af316a5
 
 
 
 
02001f5
af316a5
02001f5
af316a5
 
605629e
 
 
af316a5
605629e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import streamlit as st
from PIL import Image
import torch
from torchvision import models, transforms
import json
import os
import tempfile

# Page-level Streamlit settings (title shown in English and Estonian).
# Kept as the first Streamlit call in the script, before any other st.* output.
_PAGE_SETTINGS = dict(
    page_title="Butterfly Identifier/liblika tuvastaja",
    page_icon="🦋",
    layout="wide",
)
st.set_page_config(**_PAGE_SETTINGS)

# Load class names: one label per line, listed in the index order of the
# model's output layer (predictions index straight into this list).
# Blank lines (e.g. a trailing empty line) are skipped so they cannot inflate
# len(class_names) and make the checkpoint's final layer fail to load.
# "utf-8-sig" reads plain UTF-8 and also strips a leading BOM if present.
with open("class_names.txt", "r", encoding="utf-8-sig") as f:
    class_names = [line.strip() for line in f if line.strip()]

# Load optional per-species info. A missing, unreadable or malformed file is
# not fatal: the app then shows predictions without descriptions.
try:
    with open("butterfly_info.json", "r", encoding="utf-8-sig") as f:
        butterfly_info = json.load(f)
except (OSError, ValueError):
    # ValueError covers json.JSONDecodeError and UnicodeDecodeError; a narrow
    # catch avoids swallowing KeyboardInterrupt/SystemExit like a bare except.
    butterfly_info = {}

@st.cache_resource
def _build_model(model_path):
    """Build the ResNet-18 classifier and load its weights from *model_path*.

    Cached with ``st.cache_resource`` (keyed on *model_path*) so the weights
    are read once and reused across reruns instead of on every interaction.
    """
    model = models.resnet18(pretrained=False)
    # Replace the final fully-connected layer so it has one output per class.
    model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))
    # map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only host.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    # Inference mode (e.g. batch-norm layers use their running statistics).
    model.eval()
    return model


def load_model():
    """Return the eval-mode classifier, or ``None`` if the weights file is missing.

    The existence check deliberately runs outside the cache on every rerun:
    caching a ``None`` result would keep the app broken even after
    ``butterfly_classifier.pth`` is added.
    """
    MODEL_PATH = "butterfly_classifier.pth"

    if not os.path.exists(MODEL_PATH):
        st.error("Model file not found. Please upload butterfly_classifier.pth to your space.")
        return None

    return _build_model(MODEL_PATH)

# Obtain the classifier. load_model() shows its own error message and returns
# None when the weights are unavailable; there is nothing to do without it.
model = load_model()
if model is None:
    st.stop()

# Preprocessing applied to each uploaded image: resize to 224x224 pixels,
# then convert the PIL image to a tensor.
# NOTE(review): there is no Normalize step — confirm this matches the
# preprocessing used when butterfly_classifier.pth was trained.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocessing_steps)

def _load_uploaded_image(uploaded_file):
    """Decode a Streamlit upload into an RGB PIL image via a temporary file.

    The upload's bytes are written to a temp file (suffix ``.jpg``; Pillow
    identifies the actual format from the file content), decoded, and the
    temp file is always deleted afterwards — even when decoding fails.
    """
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
    try:
        with tmp_file:
            # getvalue() returns the whole buffer regardless of read position.
            tmp_file.write(uploaded_file.getvalue())
        # convert() returns a new, fully loaded image, so closing PIL's file
        # handle here is safe and allows the temp file to be removed (an
        # open file cannot be deleted on Windows).
        with Image.open(tmp_file.name) as raw_image:
            return raw_image.convert("RGB")
    finally:
        os.unlink(tmp_file.name)


def _predict_species(image):
    """Return the name of the highest-scoring class for an RGB PIL *image*.

    Uses the module-level ``model``, ``transform`` and ``class_names``.
    """
    # unsqueeze(0) adds a leading batch dimension of size 1.
    input_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        output = model(input_tensor)
    # Index of the largest output along dim 1 (one entry per class).
    _, pred = torch.max(output, 1)
    return class_names[pred.item()]


st.title("🦋 Butterfly Identifier")
st.write("Upload a butterfly image and I'll tell you what species it is!")

# File upload and prediction; any failure is reported in the UI rather than
# crashing the script.
try:
    uploaded_file = st.file_uploader(
        "Choose an image...",
        type=["jpg", "jpeg", "png"],
        help="Upload a clear photo of a butterfly",
        key="butterfly_image",
    )

    if uploaded_file is not None:
        image = _load_uploaded_image(uploaded_file)
        # NOTE(review): use_column_width is deprecated in newer Streamlit
        # releases in favour of use_container_width; kept for compatibility.
        st.image(image, caption="Uploaded Image", use_column_width=True)

        predicted_class = _predict_species(image)
        st.success(f"**Prediction: {predicted_class}**")

        # Descriptions are optional; show one only when the entry has it.
        # NOTE(review): assumes each butterfly_info value is a dict — confirm schema.
        if predicted_class in butterfly_info:
            description = butterfly_info[predicted_class].get("description")
            if description:
                st.info(description)

except Exception as e:
    # Top-level UI boundary: surface the error to the user.
    st.error(f"Error with file upload: {str(e)}")
    st.info("If you continue to see this error, try refreshing the page or using a different browser.")