Spaces:
Sleeping
Sleeping
import streamlit as st | |
from PIL import Image | |
import torch | |
from torchvision import models, transforms | |
import json | |
import os | |
import tempfile | |
# Page-level settings for the app: wide layout, butterfly tab icon, and a
# bilingual (English / Estonian) browser-tab title.
st.set_page_config(layout="wide", page_icon="🦋", page_title="Butterfly Identifier/liblika tuvastaja")
# Load class labels, one per line. Position i is the label for model output
# index i (the prediction code indexes this list with the argmax index).
# Read as UTF-8 explicitly so non-ASCII species names decode the same on
# every platform instead of depending on the locale's default encoding.
with open("class_names.txt", "r", encoding="utf-8") as f:
    # Blank lines are intentionally kept: load_model sizes the classifier head
    # from len(class_names), so dropping entries could mismatch the checkpoint.
    class_names = [line.strip() for line in f]
# Load optional per-species details, keyed by predicted class name. If the
# file is absent or unusable, fall back to an empty dict so predictions still
# display without a description.
try:
    with open("butterfly_info.json", "r", encoding="utf-8") as f:
        butterfly_info = json.load(f)
except (OSError, ValueError):
    # OSError: missing/unreadable file. ValueError covers malformed JSON
    # (json.JSONDecodeError) and bad bytes (UnicodeDecodeError). Interrupts,
    # SystemExit and unrelated programming errors are not swallowed.
    butterfly_info = {}
@st.cache_resource
def _build_model(model_path, num_classes):
    """Build the ResNet-18 classifier and load its weights from *model_path*.

    Streamlit re-executes the whole script on every interaction, so this is
    wrapped in ``st.cache_resource`` to read the checkpoint once per process
    instead of on every rerun.

    Args:
        model_path: Path to a ``state_dict`` checkpoint for this architecture.
        num_classes: Size of the final linear layer (number of class labels).

    Returns:
        The model in eval mode, on CPU.
    """
    # num_classes sizes the final Linear layer directly, which is equivalent to
    # swapping model.fc afterwards and avoids the deprecated ``pretrained`` arg.
    # The random init is overwritten by load_state_dict below.
    net = models.resnet18(num_classes=num_classes)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
    net.eval()
    return net


def load_model(model_path="butterfly_classifier.pth"):
    """Return the butterfly classifier, or ``None`` if the checkpoint is missing.

    When the file does not exist, an error is shown in the page via
    ``st.error`` and ``None`` is returned so the caller can stop the script.

    Args:
        model_path: Checkpoint location; defaults to the file shipped with the app.

    Returns:
        The eval-mode model, or ``None`` if *model_path* does not exist.
    """
    if not os.path.exists(model_path):
        st.error(f"Model file not found. Please upload {model_path} to your space.")
        return None
    return _build_model(model_path, len(class_names))
model = load_model()
if model is None:
    # load_model() has already shown the reason via st.error; end this run so
    # nothing below executes without a model.
    st.stop()

# Preprocessing for each uploaded image: resize to exactly 224x224 (aspect
# ratio is not preserved), then convert the PIL image to a tensor.
# NOTE(review): no transforms.Normalize step is applied here. Confirm this
# matches the preprocessing used when butterfly_classifier.pth was trained.
transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)
st.title("🦋 Butterfly Identifier")
st.write("Upload a butterfly image and I'll tell you what species it is!")

# Upload + prediction. The try/except is the page-level error boundary: any
# failure is reported in the UI instead of as a raw traceback.
try:
    uploaded_file = st.file_uploader(
        "Choose an image...",
        type=["jpg", "jpeg", "png"],
        help="Upload a clear photo of a butterfly",
        key="butterfly_image",
    )
    if uploaded_file is not None:
        # Decode straight from the in-memory upload (it is file-like). This
        # avoids a temporary file that could be left behind if decoding or
        # inference raises, and a misleading ".jpg" suffix for PNG uploads.
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Image", use_column_width=True)

        # Preprocess into a batch of one image.
        input_tensor = transform(image).unsqueeze(0)

        # Predict: pick the class whose output score is highest.
        with torch.no_grad():
            output = model(input_tensor)
        predicted_class = class_names[output.argmax(dim=1).item()]
        st.success(f"**Prediction: {predicted_class}**")

        # Show a description only when the entry actually has one. A missing
        # key would otherwise surface below as a bogus "file upload" error
        # right after a successful prediction.
        info = butterfly_info.get(predicted_class)
        if isinstance(info, dict) and "description" in info:
            st.info(info["description"])
except Exception as e:
    st.error(f"Error with file upload: {e}")
    st.info("If you continue to see this error, try refreshing the page or using a different browser.")