Spaces:
Sleeping
Sleeping
import streamlit as st | |
from PIL import Image | |
import torch | |
from torchvision import models, transforms | |
import json | |
import os | |
import io | |
import numpy as np | |
from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration | |
import av | |
import cv2 | |
import timm | |
# Configure the Streamlit page (browser title, favicon, wide layout).
st.set_page_config(
    page_title="Butterfly Identifier/Liblikamaja ID",
    page_icon="🦋",
    layout="wide",
)

# Class labels, one per line; line i is the name used for model output index i.
# Explicit UTF-8 so non-ASCII species names decode the same on every platform.
with open("class_names.txt", "r", encoding="utf-8") as f:
    class_names = [line.strip() for line in f]

# Optional per-species details keyed by class name. The app works without this
# file, so a missing/unreadable/malformed file falls back to an empty dict.
try:
    with open("butterfly_info.json", "r", encoding="utf-8") as f:
        butterfly_info = json.load(f)
except (OSError, ValueError):
    # OSError: missing or unreadable file. ValueError: covers
    # json.JSONDecodeError and UnicodeDecodeError. Unlike a bare `except:`,
    # this does not swallow KeyboardInterrupt / SystemExit.
    butterfly_info = {}
# Lookups elsewhere index this by class name, so anything but a dict is unusable.
if not isinstance(butterfly_info, dict):
    butterfly_info = {}
@st.cache_resource
def _load_classifier(model_path, fallback_num_classes):
    """Build an EfficientNet-B0 classifier and load its weights from *model_path*.

    Cached with ``st.cache_resource`` so the checkpoint is read once per
    process instead of on every Streamlit rerun (every widget interaction
    re-executes the script).

    Args:
        model_path: path to a ``.pth`` file holding a state dict, optionally
            wrapped as ``{"state_dict": ...}`` or ``{"model_state_dict": ...}``.
        fallback_num_classes: head size to use when the checkpoint has no
            ``classifier.weight`` / ``fc.weight`` entry.

    Returns:
        ``(model, num_classes)`` with the model already in eval mode.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location="cpu")

    # A real state dict never has these literal keys (its keys are parameter
    # paths), so unwrapping them is safe and accepts common training-script formats.
    if isinstance(checkpoint, dict):
        for wrapper_key in ("state_dict", "model_state_dict"):
            if isinstance(checkpoint.get(wrapper_key), dict):
                checkpoint = checkpoint[wrapper_key]
                break

    # Size the head from the saved weights so load_state_dict does not fail on a
    # shape mismatch when class_names.txt has a different number of lines.
    if "classifier.weight" in checkpoint:
        num_classes = checkpoint["classifier.weight"].shape[0]
    elif "fc.weight" in checkpoint:  # alternative head naming
        num_classes = checkpoint["fc.weight"].shape[0]
    else:
        num_classes = fallback_num_classes

    model = timm.create_model("efficientnet_b0", pretrained=False, num_classes=num_classes)
    model.load_state_dict(checkpoint)
    model.eval()
    return model, num_classes


def load_model():
    """Return the butterfly classifier, or ``None`` if its weights file is missing.

    When ``butterfly_classifier.pth`` does not exist, an error is shown in the
    UI and ``None`` is returned. Otherwise the (cached) model is returned, and
    a warning is shown if its number of outputs differs from the number of
    entries in ``class_names``.
    """
    MODEL_PATH = "butterfly_classifier.pth"
    if not os.path.exists(MODEL_PATH):
        st.error("Model file not found. Please upload butterfly_classifier.pth to your space.")
        return None

    model, num_classes = _load_classifier(MODEL_PATH, len(class_names))
    if num_classes != len(class_names):
        st.warning(
            f"Model has {num_classes} output classes but class_names.txt lists "
            f"{len(class_names)}; predictions may not map to the right names."
        )
    return model
# Build the classifier for this run. load_model() has already shown an error in
# the UI when it returns None, so the only thing left to do is halt the script.
model = load_model()
if model is None:
    st.stop()

# Preprocessing applied to every image before inference: resize to a fixed
# 224x224 input and convert to a float tensor (ToTensor scales 8-bit pixels to [0, 1]).
# NOTE(review): there is no transforms.Normalize step — confirm this matches the
# preprocessing used when butterfly_classifier.pth was trained.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)
def predict_butterfly(image):
    """Classify one image and return ``(species_name, confidence)``.

    Args:
        image: a PIL image, a NumPy array accepted by ``Image.fromarray`` (the
            live-camera path passes an RGB array), or ``None``.

    Returns:
        ``(predicted_class, confidence)``. ``confidence`` is the softmax
        probability of the top class as a Python float. Returns
        ``(None, None)`` when ``image`` is ``None``. If the model predicts an
        index with no entry in ``class_names``, a placeholder label naming that
        index is returned instead of raising IndexError.
    """
    if image is None:
        return None, None

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # The network takes 3-channel input; RGBA, grayscale or palette images would
    # otherwise yield a tensor with the wrong number of channels.
    if image.mode != "RGB":
        image = image.convert("RGB")

    input_tensor = transform(image).unsqueeze(0)  # add a batch dimension of 1

    with torch.no_grad():
        logits = model(input_tensor)
        probabilities = torch.softmax(logits[0], dim=0)
        confidence, pred = torch.max(probabilities, dim=0)

    index = pred.item()
    # The checkpoint's head can be larger than class_names.txt (load_model sizes
    # the head from the weights), so guard the lookup.
    if 0 <= index < len(class_names):
        predicted_class = class_names[index]
    else:
        predicted_class = f"Unknown (class {index})"
    return predicted_class, confidence.item()
# Video frame processor for a live camera stream.
class VideoProcessor:
    """Overlay the classifier's latest confident prediction on video frames.

    ``recv`` receives a frame exposing ``to_ndarray`` and returns an
    ``av.VideoFrame``. Inference runs only on every ``every_n_frames``-th frame;
    the resulting label is drawn on every outgoing frame until the next
    inference replaces or clears it.

    NOTE(review): nothing in the visible code instantiates this class or calls
    ``webrtc_streamer``; the "Live Camera" tab uses ``st.camera_input``.
    """

    def __init__(self, every_n_frames=30, min_confidence=0.80):
        # Text drawn on each frame; "" means draw nothing.
        self.prediction_text = ""
        self.frame_count = 0
        # Inference is expensive, so only sample every N-th frame (at least 1).
        self.every_n_frames = max(1, int(every_n_frames))
        # Same ">= threshold" rule the image tabs use for showing a prediction.
        self.min_confidence = min_confidence

    def recv(self, frame):
        """Optionally classify *frame*, draw the current label, and return it."""
        img = frame.to_ndarray(format="bgr24")

        self.frame_count += 1
        if self.frame_count % self.every_n_frames == 0:
            # The model was fed RGB images elsewhere; frames arrive as BGR.
            rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            predicted_class, confidence = predict_butterfly(rgb_img)
            if predicted_class and confidence >= self.min_confidence:
                self.prediction_text = f"{predicted_class} ({confidence:.2f})"
            else:
                # Clear the old label so it doesn't linger after the butterfly
                # leaves the frame or the prediction becomes uncertain.
                self.prediction_text = ""

        if self.prediction_text:
            cv2.putText(img, self.prediction_text, (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

        return av.VideoFrame.from_ndarray(img, format="bgr24")
# ---------------------------------------------------------------------------
# Page layout: heading, short intro, then one tab per way to provide an image.
# ---------------------------------------------------------------------------
st.title("🦋 Butterfly Identifier / Liblikamaja ID")
st.write("Identify butterflies using your camera or by uploading an image!")

tab1, tab2 = st.tabs(
    [
        "📷 Live Camera",
        "📁 Upload Image",
    ]
)
def _show_low_confidence_tips(confidence):
    """Report that the prediction was too uncertain and suggest better photos."""
    st.warning("⚠️ **Image not clear - Unable to identify butterfly**")
    if confidence is not None:
        st.info(f"Confidence too low: {confidence:.1%}")
    st.markdown("**Tips for better results:**")
    st.markdown("- Use better lighting")
    st.markdown("- Get closer to the butterfly")
    st.markdown("- Ensure the butterfly is clearly visible")
    st.markdown("- Avoid blurry or dark images")


def _show_species_info(predicted_class):
    """Show the ``description`` stored in butterfly_info for *predicted_class*, if any.

    Entries that are missing, not a dict, or lack a ``description`` are
    skipped quietly instead of raising after a valid prediction was shown.
    """
    entry = butterfly_info.get(predicted_class) if isinstance(butterfly_info, dict) else None
    description = entry.get("description") if isinstance(entry, dict) else None
    if description:
        st.write("**Species Information:**")
        st.write(description)


def _identify_and_render(image_file, caption, show_species_info, min_confidence=0.80):
    """Decode *image_file*, display it, and show the classifier's verdict beside it.

    Args:
        image_file: the file object returned by ``st.camera_input`` or
            ``st.file_uploader``.
        caption: caption shown under the displayed image.
        show_species_info: also show the butterfly_info description for a
            confident prediction.
        min_confidence: predictions whose probability is below this value are
            reported as "not clear" instead of being shown.
    """
    try:
        # getvalue() returns the whole buffer regardless of the current read
        # position, so decoding does not depend on earlier reads of the file.
        image = Image.open(io.BytesIO(image_file.getvalue())).convert("RGB")
        col1, col2 = st.columns(2)
        with col1:
            st.image(image, caption=caption, use_column_width=True)
        with col2:
            predicted_class, confidence = predict_butterfly(image)
            if predicted_class and confidence >= min_confidence:
                st.success(f"**Prediction: {predicted_class}**")
                st.info(f"Confidence: {confidence:.2%}")
                if show_species_info:
                    _show_species_info(predicted_class)
            else:
                _show_low_confidence_tips(confidence)
    except Exception as e:
        # UI boundary: show decode/inference failures to the user rather than
        # crashing the whole page.
        st.error(f"Error processing image: {e}")


with tab1:
    st.header("Camera Capture")
    st.write("Take a photo of a butterfly for identification!")
    camera_photo = st.camera_input("Take a picture of a butterfly")
    if camera_photo is not None:
        # Camera captures show the prediction without the species description.
        _identify_and_render(camera_photo, "Captured Image", show_species_info=False)

with tab2:
    st.header("Upload Image")
    st.write("Upload a clear photo of a butterfly for identification")
    uploaded_file = st.file_uploader(
        "Choose an image...",
        type=["jpg", "jpeg", "png"],
        help="Upload a clear photo of a butterfly",
    )
    if uploaded_file is not None:
        _identify_and_render(uploaded_file, "Uploaded Image", show_species_info=True)
# Usage instructions rendered below both tabs, separated by a horizontal rule.
st.markdown("---")
st.markdown("### How to use:")
for _usage_step in (
    "1. **Camera Capture**: Take a photo using your device camera",
    "2. **Upload Image**: Choose a butterfly photo from your device",
    "3. **Best Results**: Use clear, well-lit photos with the butterfly clearly visible",
):
    st.markdown(_usage_step)