# (Hugging Face Spaces page header residue: "Spaces: Sleeping")
# Bilingual (Estonian / English) Streamlit app for identifying butterflies.

# Standard library
import io
import json
import os
import warnings

# Third-party
import albumentations as A
import numpy as np
import streamlit as st
import timm
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image

# Silence Python warnings for the rest of the run (installed after all imports,
# so warnings raised while importing are not affected).
warnings.filterwarnings('ignore')

# Page-level configuration, issued before any other element is rendered.
st.set_page_config(
    page_title="Butterfly Identifier / Liblikamaja ID",
    page_icon="🦋",
    layout="wide",
)
# Load class names
def load_class_names(path="class_names.txt"):
    """Read the class labels, one per line.

    ``predict_butterfly`` indexes this list with the model's arg-max output,
    so line order must match the checkpoint's class order.

    Args:
        path: text file with one label per line.

    Returns:
        The stripped labels, or ``[]`` (after an ``st.error``) if the file is
        missing.
    """
    try:
        # utf-8-sig: decodes plain UTF-8 and also drops a leading BOM, which
        # str.strip() would otherwise leave glued to the first label.
        with open(path, "r", encoding="utf-8-sig") as f:
            names = [line.strip() for line in f]
    except FileNotFoundError:
        st.error(f"{path} file not found!")
        return []
    # Trailing blank lines would add phantom classes and make len(class_names)
    # disagree with the checkpoint head. Interior lines are left alone so every
    # label keeps its index.
    while names and not names[-1]:
        names.pop()
    return names


class_names = load_class_names()
# Load butterfly info
def load_butterfly_info(path="butterfly_info.json"):
    """Load optional per-species details, keyed by class name.

    The UI looks up ``butterfly_info[predicted_class]["description"]``.

    Args:
        path: JSON file containing an object at the top level.

    Returns:
        The parsed mapping. Returns ``{}`` silently when the file does not
        exist (it is optional). Returns ``{}`` with an ``st.warning`` when the
        file exists but is unreadable, is not valid JSON, or is not an object.
    """
    try:
        # utf-8-sig: species descriptions are non-ASCII (Estonian), and
        # json.load rejects a leading BOM that editors sometimes add.
        with open(path, "r", encoding="utf-8-sig") as f:
            data = json.load(f)
    except FileNotFoundError:
        return {}
    except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
        st.warning(f"Could not load {path}: {e}")
        return {}
    if not isinstance(data, dict):
        st.warning(f"Could not load {path}: expected a JSON object")
        return {}
    return data


butterfly_info = load_butterfly_info()
# Preprocessing for every image before inference. The original note says it
# mirrors the training pipeline: resize to 224x224, normalise with the
# ImageNet channel mean/std, then turn the HWC numpy array into a CHW tensor.
_inference_steps = [
    A.Resize(height=224, width=224),
    A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    ToTensorV2(),
]
inference_transform = A.Compose(_inference_steps)
# Load the model

# Classifier input width (timm ``num_features``) -> EfficientNet variant.
# timm derives num_features as round_channels(1280 * channel_multiplier), so
# b0 and b1 share 1280 (they differ only in depth; see _detect_architecture)
# and no standard variant has a width of 1920.
_EFFICIENTNET_BY_WIDTH = {
    1280: 'efficientnet_b0',
    1408: 'efficientnet_b2',
    1536: 'efficientnet_b3',
    1792: 'efficientnet_b4',
    2048: 'efficientnet_b5',
    2304: 'efficientnet_b6',
    2560: 'efficientnet_b7',
}
_FALLBACK_ARCH = 'efficientnet_b3'


def _extract_state_dict(checkpoint):
    """Return the weight mapping from a raw ``torch.load`` result.

    Accepts either a training checkpoint with a ``model_state_dict`` entry or
    a bare state dict. If every key carries a ``module.`` prefix (as left by
    ``nn.DataParallel``), the prefix is removed.
    """
    state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
    prefix = 'module.'
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
    return state_dict


def _head_width(state_dict):
    """Return the feature width that feeds the classifier, or None if not found.

    Only exact top-level keys are checked. A suffix match on "bn2.weight"
    would first hit a per-block layer such as ``blocks.0.0.bn2.weight``,
    which holds a small stage width rather than the head width.
    """
    if 'bn2.weight' in state_dict:
        return state_dict['bn2.weight'].shape[0]
    if 'conv_head.weight' in state_dict:
        return state_dict['conv_head.weight'].shape[0]
    if 'classifier.weight' in state_dict:
        return state_dict['classifier.weight'].shape[1]
    return None


def _detect_architecture(state_dict):
    """Return the timm EfficientNet name whose head width matches ``state_dict``.

    Falls back to ``efficientnet_b3`` with a warning when no head layer is
    found or when the width is not a standard variant.
    """
    width = _head_width(state_dict)
    if width is None:
        st.warning("Could not detect classifier or bn2 layer in checkpoint. Defaulting to efficientnet_b3")
        return _FALLBACK_ARCH
    model_name = _EFFICIENTNET_BY_WIDTH.get(width)
    if model_name is None:
        st.warning(f"Unrecognised head width {width} in checkpoint. Defaulting to {_FALLBACK_ARCH}")
        return _FALLBACK_ARCH
    # b1 scales b0's depth by 1.1 (rounded up), which gives its first stage a
    # second block ("blocks.0.1.*") that b0 does not have.
    if model_name == 'efficientnet_b0' and any(key.startswith('blocks.0.1.') for key in state_dict):
        model_name = 'efficientnet_b1'
    return model_name


@st.cache_resource(show_spinner=False)
def _build_model(model_path, num_classes):
    """Load ``model_path`` into a matching timm model in eval mode.

    Streamlit reruns the whole script on every interaction. Caching keeps the
    checkpoint from being re-read and the network rebuilt on each photo. The
    cached instance is shared across sessions, so callers must not mutate it.
    Exceptions propagate, so failures are not cached.
    """
    checkpoint = torch.load(model_path, map_location='cpu')
    state_dict = _extract_state_dict(checkpoint)
    model_name = _detect_architecture(state_dict)
    st.info(f"Detected model architecture: {model_name}")
    model = timm.create_model(
        model_name,
        pretrained=False,
        num_classes=num_classes,
        drop_rate=0.4,
        drop_path_rate=0.3,
    )
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        # strict=False would otherwise leave unmatched layers randomly
        # initialised with no visible sign that predictions are unreliable.
        st.warning(
            f"Checkpoint does not fully match {model_name}: "
            f"{len(result.missing_keys)} missing and "
            f"{len(result.unexpected_keys)} unexpected weights."
        )
    model.eval()
    return model


def load_model(model_path="butterfly_classifier.pth"):
    """Return the butterfly classifier in eval mode, or None on failure.

    Args:
        model_path: checkpoint file produced by training.

    Returns:
        The model, built once per (path, class count) and reused across
        reruns. Returns None when the file or the class names are missing, or
        when loading fails; the reason is shown with ``st.error``.
    """
    if not os.path.exists(model_path):
        st.error("Model file not found!")
        return None
    if not class_names:
        # Without labels the head would have zero outputs and every
        # prediction would fail with an index error.
        st.error("Cannot build the model without class names.")
        return None
    try:
        return _build_model(model_path, len(class_names))
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None


model = load_model()
def predict_butterfly(image, threshold=0.5):
    """Classify a butterfly photo with the module-level ``model``.

    Args:
        image: a PIL image or a numpy array (wrapped via ``Image.fromarray``).
            Any non-RGB mode is converted to RGB.
        threshold: minimum top softmax probability needed to return a label.

    Returns:
        ``(label, confidence)`` when the top probability reaches ``threshold``.
        ``(None, confidence)`` when it falls short.
        ``(None, None)`` when ``image`` is None or any error occurs; errors,
        including a missing model, are shown via ``st.error``.
    """
    try:
        if model is None:
            raise ValueError("Model is not loaded.")
        if image is None:
            return None, None

        pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        if pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')

        batch = inference_transform(image=np.array(pil_image))['image'].unsqueeze(0)
        with torch.no_grad():
            logits = model(batch)
        probs = torch.softmax(logits[0], dim=0)
        top_prob, top_idx = probs.max(dim=0)
        score = top_prob.item()

        if score < threshold:
            return None, score
        return class_names[top_idx.item()], score
    except Exception as e:
        st.error(f"Prediction error: {str(e)}")
        return None, None
# UI Code


def _load_rgb_image(uploaded):
    """Decode a Streamlit upload or camera capture into an upright RGB PIL image.

    ``getvalue()`` returns the whole payload regardless of the buffer's read
    position. Phone photos often store rotation as EXIF metadata instead of
    rotating the pixels, so that orientation is applied before display and
    classification.
    """
    from PIL import ImageOps  # only needed here

    image = Image.open(io.BytesIO(uploaded.getvalue()))
    try:
        image = ImageOps.exif_transpose(image)
    except Exception:
        # Best effort: malformed EXIF must not block identification.
        pass
    return image.convert("RGB")


def _species_description(name):
    """Return the ``"description"`` stored for ``name`` in ``butterfly_info``, or None."""
    entry = butterfly_info.get(name) if isinstance(butterfly_info, dict) else None
    return entry.get("description") if isinstance(entry, dict) else None


def _show_identification(image, caption):
    """Show ``image`` next to its predicted species, or tips when unsure.

    Shared by both tabs so they render results identically.
    """
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption=caption, use_column_width=True)
    with col2:
        with st.spinner("Pildi analüüsimine... / Analyzing image..."):
            predicted_class, confidence = predict_butterfly(image)
        if predicted_class and confidence >= 0.50:
            st.success(f"**Liblikas / Butterfly: {predicted_class}**")
            description = _species_description(predicted_class)
            # A species entry without a description must not turn a
            # successful prediction into an error.
            if description is not None:
                st.markdown("**Liigi kirjeldus / About this species:**")
                st.write(description)
        else:
            st.warning("⚠️ Ma ei tea, mis liblikas see on / I don't know what butterfly this is")
            st.markdown("**Näpunäited paremate tulemuste saavutamiseks / Tips for better results:**")
            st.markdown("- Kasutage paremat valgustust / Use better lighting")
            st.markdown("- Veenduge, et liblikas oleks selgelt nähtav / Ensure the butterfly is clearly visible")
            st.markdown("- Vältige uduseid või tumedaid pilte / Avoid blurry or dark images")


st.title("🦋 Liblikamaja ID / Butterfly Identifier")
st.write("Tuvasta liblikaid oma kaamera abil või laadi üles pilt! / Identify butterflies using your camera or by uploading an image!")
tab1, tab2 = st.tabs(["📷 Live Camera / Kaamera", "📁 Upload Image / Laadi üles"])

with tab1:
    st.header("Kaamera jäädvustamine / Camera Capture")
    st.write("Tee pilt liblikast tuvastamiseks / Take a photo of a butterfly for identification.")
    camera_photo = st.camera_input("Pildista liblikat / Capture a butterfly")
    if camera_photo is not None:
        try:
            _show_identification(_load_rgb_image(camera_photo), "Jäädvustatud pilt / Captured Image")
        except Exception as e:
            st.error(f"Error processing image: {str(e)}")

with tab2:
    st.header("Laadi üles pilt / Upload Image")
    st.write("Laadige üles liblika selge foto tuvastamiseks / Upload a clear photo of a butterfly for identification.")
    uploaded_file = st.file_uploader("Vali pilt... / Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        try:
            _show_identification(_load_rgb_image(uploaded_file), "Üleslaetud pilt / Uploaded Image")
        except Exception as e:
            st.error(f"Error processing image: {str(e)}")
# Footer: a short usage guide rendered below both tabs, one markdown call per line.
_FOOTER_LINES = (
    "---",
    "### Kuidas kasutada / How to use:",
    "1. **Kaamera jäädvustamine / Camera Capture**: Tehke foto oma seadme kaameraga / Take a photo using your device camera",
    "2. **Laadi pilt üles / Upload Image**: Vali oma seadmest liblika foto / Choose a butterfly photo from your device",
    "3. **Parimad tulemused / Best Results**: Kasutage selgeid ja hästi valgustatud fotosid, kus liblikas on selgelt nähtav / Use clear, well-lit photos with the butterfly clearly visible",
)
for _footer_line in _FOOTER_LINES:
    st.markdown(_footer_line)