Spaces:
Sleeping
Sleeping
File size: 2,680 Bytes
02001f5 605629e 02001f5 af316a5 02001f5 a233bbd 8bcdd21 a233bbd 8bcdd21 02001f5 a233bbd 02001f5 6052c4b 02001f5 af316a5 02001f5 605629e af316a5 02001f5 af316a5 02001f5 af316a5 02001f5 af316a5 02001f5 af316a5 605629e af316a5 605629e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import streamlit as st
from PIL import Image
import torch
from torchvision import models, transforms
import json
import os
import tempfile
# Page-level Streamlit settings: browser-tab title and icon, plus the wide
# layout so the uploaded image and results can use the full window width.
st.set_page_config(
    layout="wide",
    page_icon="🦋",
    page_title="Butterfly Identifier/liblika tuvastaja",
)
# Class labels, one per line. The line order defines the model's output
# indices (the classifier head is sized with len(class_names) and predictions
# are mapped back via class_names[index]), so lines must not be filtered or
# reordered here.
with open("class_names.txt", "r", encoding="utf-8") as f:
    class_names = [line.strip() for line in f]

# Optional per-species metadata looked up by predicted class name. The app
# still works without it, so a missing file silently yields an empty dict.
try:
    with open("butterfly_info.json", "r", encoding="utf-8") as f:
        butterfly_info = json.load(f)
except FileNotFoundError:
    butterfly_info = {}
except (OSError, ValueError) as err:
    # The file exists but is unreadable or malformed (json.JSONDecodeError
    # and UnicodeDecodeError are ValueError subclasses). Keep the app
    # running, but surface why descriptions are missing instead of hiding it.
    st.warning(f"Could not load butterfly_info.json: {err}")
    butterfly_info = {}
@st.cache_resource
def _build_model(model_path, num_classes):
    """Build a ResNet-18 classifier and load its weights from *model_path*.

    Cached by Streamlit per (model_path, num_classes), so the weights are read
    once per process. The cache is also invalidated if the number of class
    labels changes.

    Args:
        model_path: Path to a ``state_dict`` saved with ``torch.save``.
        num_classes: Output size of the replacement final linear layer.

    Returns:
        The model in eval mode on CPU.

    Raises:
        RuntimeError: if the checkpoint does not match the architecture
            (e.g. a different number of classes).
    """
    model = models.resnet18(weights=None)  # random init; weights come from the file
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    # weights_only=True refuses to unpickle arbitrary objects. That suffices
    # here because load_state_dict() below only needs tensors.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def load_model():
    """Return the butterfly classifier, or ``None`` if the weights file is missing.

    The existence check runs outside the Streamlit cache. A missing file is
    therefore re-checked on every rerun instead of caching ``None`` forever.
    Once the file is added, the next rerun loads it.
    """
    MODEL_PATH = "butterfly_classifier.pth"
    if not os.path.exists(MODEL_PATH):
        st.error("Model file not found. Please upload butterfly_classifier.pth to your space.")
        return None
    return _build_model(MODEL_PATH, len(class_names))
# Halt this script run when no classifier is available; load_model() has
# already shown the error message in that case.
model = load_model()
if model is None:
    st.stop()

# Preprocessing applied to each uploaded image before inference: resize to a
# fixed 224x224 and convert the PIL image to a float tensor.
# NOTE(review): there is no transforms.Normalize step; confirm this matches
# the preprocessing used when the weights were trained.
_preprocess_steps = [
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)
def _load_uploaded_image(uploaded_file):
    """Decode a Streamlit upload into an RGB PIL image.

    The bytes are written to a named temporary file. That file is removed in
    a ``finally`` clause, so it is cleaned up even when decoding fails.
    """
    # Keep the upload's own extension; fall back to ".jpg" as before.
    suffix = os.path.splitext(uploaded_file.name)[1] or ".jpg"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
        # getvalue() returns the whole buffer regardless of the current read
        # position. read() would return b"" if the stream was already consumed.
        tmp_file.write(uploaded_file.getvalue())
        tmp_file_path = tmp_file.name
    try:
        # Close the image handle before unlinking (required on Windows).
        # convert() returns a converted in-memory copy that outlives the handle.
        with Image.open(tmp_file_path) as img:
            return img.convert("RGB")
    finally:
        os.unlink(tmp_file_path)


def _predict_species(image):
    """Return the class name the model scores highest for *image*."""
    input_tensor = transform(image).unsqueeze(0)  # add a leading batch dimension of 1
    with torch.no_grad():
        output = model(input_tensor)
    _, pred = torch.max(output, 1)
    return class_names[pred.item()]


st.title("🦋 Butterfly Identifier")
st.write("Upload a butterfly image and I'll tell you what species it is!")

# Any failure while uploading, decoding or classifying is reported in the UI
# instead of surfacing as a raw Streamlit traceback.
try:
    uploaded_file = st.file_uploader(
        "Choose an image...",
        type=["jpg", "jpeg", "png"],
        help="Upload a clear photo of a butterfly",
        key="butterfly_image",
    )
    if uploaded_file is not None:
        image = _load_uploaded_image(uploaded_file)
        st.image(image, caption="Uploaded Image", use_column_width=True)

        predicted_class = _predict_species(image)
        st.success(f"**Prediction: {predicted_class}**")

        # Metadata is optional per species. An entry without a "description"
        # is skipped rather than raising KeyError, which would otherwise be
        # misreported below as an upload error.
        description = butterfly_info.get(predicted_class, {}).get("description")
        if description:
            st.info(description)
except Exception as e:
    st.error(f"Error with file upload: {str(e)}")
    st.info("If you continue to see this error, try refreshing the page or using a different browser.")