# libkamaja_id / streamlit_app.py
# Last commit: 9d53d92 "Update streamlit_app.py" (by leynessa, verified), 8.2 kB
# Full corrected bilingual Streamlit app for Butterfly Identifier
import streamlit as st
from PIL import Image
import torch
import json
import os
import io
import numpy as np
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import warnings
# Globally suppress Python warnings of every category.
warnings.simplefilter('ignore')
# Configure Streamlit: page-level settings, issued before any other Streamlit call.
st.set_page_config(
    layout="wide",
    page_icon="🦋",
    page_title="Butterfly Identifier / Liblikamaja ID",
)
# Load class names (one label per line; line i is the label for model output i)
@st.cache_data
def load_class_names():
    """Read the class labels from ``class_names.txt``.

    Returns:
        The labels in file order with surrounding whitespace stripped, or
        ``[]`` (after showing an error) if the file does not exist.
    """
    try:
        # Explicit encoding: the platform default is locale-dependent, and
        # "utf-8-sig" also drops a BOM that would otherwise corrupt label 0.
        with open("class_names.txt", "r", encoding="utf-8-sig") as f:
            names = [line.strip() for line in f]
    except FileNotFoundError:
        st.error("class_names.txt file not found!")
        return []
    # Trailing blank lines would otherwise become empty "classes" and inflate
    # the class count. Interior lines are kept so label indices never shift.
    while names and not names[-1]:
        names.pop()
    return names
class_names = load_class_names()
# Load butterfly info (optional per-species details keyed by class name)
@st.cache_data
def load_butterfly_info():
    """Load optional species information from ``butterfly_info.json``.

    Returns:
        The parsed JSON object when it is a dict, otherwise ``{}``. A missing
        file is silently treated as "no extra info". An unreadable or malformed
        file shows a warning but never stops the app.
    """
    try:
        with open("butterfly_info.json", "r", encoding="utf-8-sig") as f:
            info = json.load(f)
    except FileNotFoundError:
        return {}
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        st.warning(f"Could not read butterfly_info.json: {e}")
        return {}
    if not isinstance(info, dict):
        # The UI looks entries up by class name, so only a mapping is usable.
        st.warning("butterfly_info.json must contain a JSON object; ignoring it.")
        return {}
    return info
butterfly_info = load_butterfly_info()
# Preprocessing for every inference image; it must stay identical to the
# transform used during training: resize, normalise per channel (the
# standard ImageNet statistics), then convert to a torch tensor.
_INPUT_SIZE = 224
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
_inference_steps = [
    A.Resize(_INPUT_SIZE, _INPUT_SIZE),
    A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ToTensorV2(),
]
inference_transform = A.Compose(_inference_steps)
# Load the model
# Head width of timm `efficientnet_b*` models (channels of the top-level
# ``bn2`` layer = round(1280 * channel_multiplier)) -> candidate architectures.
# b0 and b1 share a width and are told apart by their parameter names.
_EFFICIENTNET_BY_HEAD_WIDTH = {
    1280: ('efficientnet_b0', 'efficientnet_b1'),
    1408: ('efficientnet_b2',),
    1536: ('efficientnet_b3',),
    1792: ('efficientnet_b4',),
    2048: ('efficientnet_b5',),
    2304: ('efficientnet_b6',),
    2560: ('efficientnet_b7',),
}
_DEFAULT_MODEL_NAME = 'efficientnet_b3'
# Key prefixes added when a model is saved through a wrapper
# (nn.DataParallel / DistributedDataParallel -> "module.", torch.compile -> "_orig_mod.").
_WRAPPER_PREFIXES = ('module.', '_orig_mod.')


def _extract_state_dict(checkpoint):
    """Return a plain ``{name: tensor}`` dict from a raw ``torch.load`` result."""
    if isinstance(checkpoint, torch.nn.Module):
        return checkpoint.state_dict()
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        checkpoint = checkpoint['model_state_dict']
    state_dict = dict(checkpoint)
    # Strip wrapper prefixes so keys line up with a bare timm model.
    for prefix in _WRAPPER_PREFIXES:
        if state_dict and all(key.startswith(prefix) for key in state_dict):
            state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
    return state_dict


def _build_model(model_name, num_classes):
    """Create an untrained timm model with the training-time head settings."""
    return timm.create_model(model_name, pretrained=False, num_classes=num_classes,
                             drop_rate=0.4, drop_path_rate=0.3)


def _select_model(state_dict, num_classes):
    """Pick the EfficientNet variant matching *state_dict*; return ``(name, model)``."""
    # Only the top-level head norm is the feature width; blocks.*.bn2 layers
    # also exist and must not be used for detection.
    head_bn = state_dict.get('bn2.weight')
    if head_bn is None:
        st.warning("Could not detect classifier or bn2 layer in checkpoint. Defaulting to efficientnet_b3")
        candidates = (_DEFAULT_MODEL_NAME,)
    else:
        width = head_bn.shape[0]
        candidates = _EFFICIENTNET_BY_HEAD_WIDTH.get(width)
        if candidates is None:
            st.warning(f"Unrecognized head width {width} in checkpoint. Defaulting to {_DEFAULT_MODEL_NAME}")
            candidates = (_DEFAULT_MODEL_NAME,)
    checkpoint_keys = set(state_dict)
    fallback = None
    for name in candidates:
        candidate = _build_model(name, num_classes)
        if set(candidate.state_dict()) == checkpoint_keys:
            return name, candidate
        if fallback is None:
            fallback = (name, candidate)
    # No exact key match: keep the first candidate; load_model reports the mismatch.
    return fallback


@st.cache_resource
def load_model():
    """Load ``butterfly_classifier.pth`` into an EfficientNet in eval mode.

    The architecture is inferred from the checkpoint's head width and the
    output size from ``class_names``.

    Returns:
        The model, or ``None`` (after showing an error) when the file is
        missing, no class names are loaded, the class count disagrees with the
        checkpoint, weights are missing, or loading raises.
    """
    MODEL_PATH = "butterfly_classifier.pth"
    if not os.path.exists(MODEL_PATH):
        st.error("Model file not found!")
        return None
    num_classes = len(class_names)
    if num_classes == 0:
        st.error("Cannot build the model: no class names were loaded.")
        return None
    try:
        # NOTE(security): torch.load unpickles the file and can execute code.
        # That is only acceptable because the checkpoint ships with the app;
        # never point this at user-supplied files.
        checkpoint = torch.load(MODEL_PATH, map_location='cpu')
        state_dict = _extract_state_dict(checkpoint)
        classifier = state_dict.get('classifier.weight')
        if classifier is not None and classifier.shape[0] != num_classes:
            st.error(f"class_names.txt lists {num_classes} classes but the checkpoint "
                     f"has {classifier.shape[0]} outputs.")
            return None
        model_name, model = _select_model(state_dict, num_classes)
        st.info(f"Detected model architecture: {model_name}")
        # strict=False tolerates extra entries. Missing ones would leave layers
        # randomly initialised and give meaningless predictions, so fail instead.
        result = model.load_state_dict(state_dict, strict=False)
        if result.missing_keys:
            st.error(f"Error loading model: {len(result.missing_keys)} weights missing from "
                     f"checkpoint (e.g. {result.missing_keys[0]})")
            return None
        if result.unexpected_keys:
            st.warning(f"Ignoring {len(result.unexpected_keys)} unexpected checkpoint entries.")
        model.eval()
        return model
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None
model = load_model()
def predict_butterfly(image, threshold=0.5):
    """Classify one image with the module-level ``model``.

    Args:
        image: a PIL image or a numpy array (converted with
            ``Image.fromarray``); ``None`` short-circuits.
        threshold: minimum top-1 softmax probability required to name a class.

    Returns:
        ``(class_name, confidence)`` when the top probability reaches
        *threshold*, ``(None, confidence)`` when it does not, and
        ``(None, None)`` for a ``None`` image or after any error (the error is
        shown with ``st.error``).
    """
    try:
        if model is None:
            raise ValueError("Model is not loaded.")
        if image is None:
            return None, None
        pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        if pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')
        # The transform consumes a numpy array; unsqueeze adds the batch axis.
        batch = inference_transform(image=np.array(pil_image))['image'].unsqueeze(0)
        with torch.no_grad():
            logits = model(batch)
        probs = torch.nn.functional.softmax(logits[0], dim=0)
        top_prob, top_idx = torch.max(probs, 0)
        score = top_prob.item()
        if score < threshold:
            return None, score
        return class_names[top_idx.item()], score
    except Exception as e:
        st.error(f"Prediction error: {str(e)}")
        return None, None
# UI Code
# Minimum top-1 probability for naming a species (shared by both tabs).
CONFIDENCE_THRESHOLD = 0.50


def _open_image(uploaded):
    """Decode a Streamlit upload / camera capture into an upright RGB PIL image."""
    from PIL import ImageOps  # local import: the file header only imports Image

    # getvalue() returns the full payload regardless of the current read position.
    image = Image.open(io.BytesIO(uploaded.getvalue()))
    # Phone cameras often record rotation in EXIF instead of rotating the
    # pixels; apply it so both the preview and the model see the image upright.
    return ImageOps.exif_transpose(image).convert("RGB")


def _show_prediction(image, caption):
    """Render *image* with *caption* next to its identification result."""
    col1, col2 = st.columns(2)
    with col1:
        # NOTE(review): newer Streamlit releases deprecate use_column_width in
        # favour of use_container_width; switch once the deployed version supports it.
        st.image(image, caption=caption, use_column_width=True)
    with col2:
        with st.spinner("Pildi analüüsimine... / Analyzing image..."):
            predicted_class, confidence = predict_butterfly(image, threshold=CONFIDENCE_THRESHOLD)
        if predicted_class and confidence >= CONFIDENCE_THRESHOLD:
            st.success(f"**Liblikas / Butterfly: {predicted_class}**")
            info = butterfly_info.get(predicted_class) if isinstance(butterfly_info, dict) else None
            description = info.get("description") if isinstance(info, dict) else None
            # Entries without a description are skipped instead of raising KeyError.
            if description:
                st.markdown("**Liigi kirjeldus / About this species:**")
                st.write(description)
        elif confidence is not None:
            # The model ran but was not confident enough. When confidence is
            # None, predict_butterfly has already shown an error message.
            st.warning("⚠️ Ma ei tea, mis liblikas see on / I don't know what butterfly this is")
            st.markdown("**Näpunäited paremate tulemuste saavutamiseks / Tips for better results:**")
            st.markdown("- Kasutage paremat valgustust / Use better lighting")
            st.markdown("- Veenduge, et liblikas oleks selgelt nähtav / Ensure the butterfly is clearly visible")
            st.markdown("- Vältige uduseid või tumedaid pilte / Avoid blurry or dark images")


st.title("🦋 Liblikamaja ID / Butterfly Identifier")
st.write("Tuvasta liblikaid oma kaamera abil või laadi üles pilt! / Identify butterflies using your camera or by uploading an image!")
tab1, tab2 = st.tabs(["📷 Live Camera / Kaamera", "📁 Upload Image / Laadi üles"])
with tab1:
    st.header("Kaamera jäädvustamine / Camera Capture")
    st.write("Tee pilt liblikast tuvastamiseks / Take a photo of a butterfly for identification.")
    camera_photo = st.camera_input("Pildista liblikat / Capture a butterfly")
    if camera_photo is not None:
        try:
            _show_prediction(_open_image(camera_photo), "Jäädvustatud pilt / Captured Image")
        except Exception as e:
            st.error(f"Error processing image: {str(e)}")
with tab2:
    st.header("Laadi üles pilt / Upload Image")
    st.write("Laadige üles liblika selge foto tuvastamiseks / Upload a clear photo of a butterfly for identification.")
    uploaded_file = st.file_uploader("Vali pilt... / Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        try:
            _show_prediction(_open_image(uploaded_file), "Üleslaetud pilt / Uploaded Image")
        except Exception as e:
            st.error(f"Error processing image: {str(e)}")
# Footer: a horizontal rule followed by the bilingual usage guide.
_FOOTER_LINES = (
    "---",
    "### Kuidas kasutada / How to use:",
    "1. **Kaamera jäädvustamine / Camera Capture**: Tehke foto oma seadme kaameraga / Take a photo using your device camera",
    "2. **Laadi pilt üles / Upload Image**: Vali oma seadmest liblika foto / Choose a butterfly photo from your device",
    "3. **Parimad tulemused / Best Results**: Kasutage selgeid ja hästi valgustatud fotosid, kus liblikas on selgelt nähtav / Use clear, well-lit photos with the butterfly clearly visible",
)
for _footer_line in _FOOTER_LINES:
    st.markdown(_footer_line)