# libkamaja_id / streamlit_app.py
# Last commit 605629e ("Update streamlit_app.py") by leynessa; file size 2.68 kB.
import streamlit as st
from PIL import Image
import torch
from torchvision import models, transforms
import json
import os
import tempfile
# Page-level Streamlit settings: browser-tab title and icon, plus the wide
# layout so content can use the full window width.
st.set_page_config(
    layout="wide",
    page_icon="🦋",
    page_title="Butterfly Identifier/liblika tuvastaja",
)
# Load class names, one label per line. The order matters: the model's output
# index i is looked up as class_names[i], and len(class_names) sizes the
# classifier head.
with open("class_names.txt", "r", encoding="utf-8") as f:
    # Blank lines (e.g. a trailing empty line) are skipped so they cannot
    # become phantom classes that change the size of the classifier head.
    class_names = [name for name in (line.strip() for line in f) if name]
# Load optional per-species info, keyed by class name (used to show a
# description after a prediction). Best-effort: a missing, unreadable, or
# malformed file simply means no descriptions are shown.
try:
    with open("butterfly_info.json", "r", encoding="utf-8") as f:
        butterfly_info = json.load(f)
except (OSError, ValueError):
    # OSError: file missing or unreadable. ValueError also covers
    # json.JSONDecodeError and UnicodeDecodeError.
    butterfly_info = {}
if not isinstance(butterfly_info, dict):
    # The prediction code looks entries up by class name, so anything other
    # than a JSON object is treated as "no info available".
    butterfly_info = {}
@st.cache_resource
def _build_model(model_path, num_classes):
    """Build a ResNet-18 classifier and load its trained weights.

    Cached with ``st.cache_resource`` so the weights are read once and the
    resulting model object is reused on later calls with the same arguments.
    It is only called after the weights file is known to exist, so a
    "file missing" outcome is never cached.

    Args:
        model_path: Path to the saved ``state_dict`` for the model.
        num_classes: Number of output units for the replacement ``fc`` layer.

    Returns:
        The model on CPU, switched to eval mode.
    """
    # Architecture only (no pretrained download); weights come from model_path.
    # ``weights=None`` replaces the deprecated ``pretrained=False``.
    model = models.resnet18(weights=None)
    # Swap the classification head so it matches the number of labels.
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all a state_dict needs, instead of arbitrary Python objects.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def load_model():
    """Return the butterfly classifier, or None if the weights file is missing.

    The existence check is deliberately kept outside the cache. If the
    cached function itself returned None, ``st.cache_resource`` would keep
    returning that None, and the app would report a missing model even after
    the file was added.

    Returns:
        The model in eval mode, or None after showing an error in the UI.
    """
    MODEL_PATH = "butterfly_classifier.pth"
    if not os.path.exists(MODEL_PATH):
        st.error("Model file not found. Please upload butterfly_classifier.pth to your space.")
        return None
    return _build_model(MODEL_PATH, len(class_names))
# Obtain the classifier. When it is unavailable, load_model() has already
# shown an error in the UI, so end this script run here.
model = load_model()
if model is None:
    st.stop()
# Preprocessing applied to each uploaded image before inference: resize to
# 224x224, then convert the PIL image to a float tensor.
# NOTE(review): there is no Normalize step; presumably this matches the
# preprocessing used during training. Confirm before adding one.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)
# Page heading followed by a one-line usage hint.
st.title("🦋 Butterfly Identifier")
st.write("Upload a butterfly image and I'll tell you what species it is!")
# File upload and prediction. The whole section runs inside a try so that any
# failure (widget, decoding, inference) is reported in the UI instead of
# surfacing as a traceback.
try:
    uploaded_file = st.file_uploader(
        "Choose an image...",
        type=["jpg", "jpeg", "png"],
        help="Upload a clear photo of a butterfly",
        key="butterfly_image",
    )
    if uploaded_file is not None:
        # Copy the upload to a temporary file and decode the image from disk.
        # PIL identifies the format from the file contents, so the ".jpg"
        # suffix does not matter for PNG uploads.
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
        try:
            with tmp_file:
                # getvalue() returns the full upload regardless of the
                # buffer's current read position; read() would not.
                tmp_file.write(uploaded_file.getvalue())
            # Decode fully inside the with-block so PIL releases its file
            # handle before the temp file is deleted below.
            with Image.open(tmp_file.name) as img:
                image = img.convert("RGB")
        finally:
            # Always remove the temp file, even if writing or decoding failed.
            os.unlink(tmp_file.name)

        # NOTE(review): newer Streamlit releases deprecate use_column_width in
        # favour of use_container_width. Confirm against the deployed version.
        st.image(image, caption="Uploaded Image", use_column_width=True)

        # Preprocess into a batch containing a single image.
        input_tensor = transform(image).unsqueeze(0)

        # Predict: pick the index of the highest-scoring output unit.
        with torch.no_grad():
            output = model(input_tensor)
        _, pred = torch.max(output, 1)
        predicted_class = class_names[pred.item()]
        st.success(f"**Prediction: {predicted_class}**")

        # Show the species description only if the info entry provides one;
        # a missing "description" key is no longer reported as an upload error.
        if predicted_class in butterfly_info:
            description = butterfly_info[predicted_class].get("description")
            if description:
                st.info(description)
except Exception as e:
    # Top-level UI boundary: show the error instead of crashing the page.
    st.error(f"Error with file upload: {str(e)}")
    st.info("If you continue to see this error, try refreshing the page or using a different browser.")