Spaces:
Sleeping
Sleeping
File size: 8,197 Bytes
e11e26a 02001f5 605629e bfc9fea 8c61898 e11e26a 1809961 02001f5 af316a5 e11e26a af316a5 02001f5 1809961 02001f5 1809961 8c61898 1809961 e11e26a 02001f5 e11e26a 02001f5 a233bbd e11e26a 1809961 9625ec8 1809961 e11e26a 1809961 06966cd 9d53d92 6be4ed9 06966cd 9d53d92 06966cd e11e26a 1809961 02001f5 7fb82fb 02001f5 118e762 8445a65 7fb82fb 8445a65 e11e26a 8445a65 8c61898 8445a65 118e762 62475ba e11e26a a91e356 e11e26a bfc9fea 9d53d92 bfc9fea e11e26a 38a7793 e11e26a 38a7793 e11e26a a91e356 8c61898 e11e26a 118e762 e11e26a 38a7793 bfc9fea e11e26a 605629e bfc9fea e11e26a bfc9fea e11e26a a91e356 e11e26a bfc9fea e11e26a bfc9fea e11e26a a91e356 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
# Full corrected bilingual Streamlit app for Butterfly Identifier
import streamlit as st
from PIL import Image
import torch
import json
import os
import io
import numpy as np
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import warnings
warnings.filterwarnings('ignore')
# Configure Streamlit: browser-tab title and icon, plus the full-width layout.
_PAGE_SETTINGS = {
    "page_title": "Butterfly Identifier / Liblikamaja ID",
    "page_icon": "🦋",
    "layout": "wide",
}
st.set_page_config(**_PAGE_SETTINGS)
# Load class names
@st.cache_data
def load_class_names():
    """Read the classifier's labels from ``class_names.txt``, one per line.

    Order matters: the i-th non-empty line is the label for the model's
    output index i (it is also used as the model's ``num_classes``).

    Returns:
        list[str]: the labels, or ``[]`` (after showing an error) when the
        file does not exist.
    """
    try:
        # Explicit UTF-8 so labels with non-ASCII characters do not depend
        # on the platform's default locale encoding.
        with open("class_names.txt", "r", encoding="utf-8") as f:
            # Skip blank lines (e.g. a trailing empty line): they would become
            # phantom "" classes and inflate the model's class count.
            return [name for name in (line.strip() for line in f) if name]
    except FileNotFoundError:
        st.error("class_names.txt file not found!")
        return []
class_names = load_class_names()
# Load butterfly info
@st.cache_data
def load_butterfly_info():
    """Load optional per-species details from ``butterfly_info.json``.

    The UI looks entries up by predicted class name and reads their
    ``"description"`` field, so the top level is expected to be an object.

    Returns:
        dict: the parsed JSON object, or ``{}`` when the file is missing,
        unreadable, malformed, or not a JSON object. The app keeps working
        without species descriptions in every one of those cases.
    """
    try:
        with open("butterfly_info.json", "r", encoding="utf-8") as f:
            info = json.load(f)
    except FileNotFoundError:
        # The file is optional; silently run without descriptions.
        return {}
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        st.warning(f"Could not read butterfly_info.json: {e}")
        return {}
    if not isinstance(info, dict):
        st.warning("butterfly_info.json must contain a JSON object; ignoring it.")
        return {}
    return info
butterfly_info = load_butterfly_info()
# Preprocessing for every image before inference. It must mirror the training
# pipeline: resize to 224x224, normalise with the standard ImageNet channel
# mean/std, then convert the HWC numpy array to a CHW torch tensor.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
inference_transform = A.Compose(
    [
        A.Resize(height=224, width=224),
        A.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
        ToTensorV2(),
    ]
)
# Load the model
# Head feature width (size of the final ``bn2`` BatchNorm, which is also the
# classifier's input size) of each timm EfficientNet variant, i.e.
# 1280 * width multiplier. b0 and b1 are both 1280 wide and differ only in
# depth, so that width has two candidates.
_EFFICIENTNET_BY_FEATURES = {
    1280: ('efficientnet_b0', 'efficientnet_b1'),
    1408: ('efficientnet_b2',),
    1536: ('efficientnet_b3',),
    1792: ('efficientnet_b4',),
    2048: ('efficientnet_b5',),
    2304: ('efficientnet_b6',),
    2560: ('efficientnet_b7',),
}
_DEFAULT_MODEL_NAME = 'efficientnet_b3'


def _strip_data_parallel_prefix(state_dict):
    """Drop the ``module.`` prefix that ``nn.DataParallel`` puts on every key."""
    prefix = 'module.'
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        return {key[len(prefix):]: value for key, value in state_dict.items()}
    return state_dict


def _detect_num_features(state_dict):
    """Return the head feature width recorded in *state_dict*, or None.

    Only the top-level ``bn2`` / ``classifier`` entries are consulted. Every
    ``blocks.*`` layer has its own, much narrower ``bn2``, so matching on the
    ``bn2.weight`` suffix would find the first block instead of the head.
    """
    if 'bn2.weight' in state_dict:
        return state_dict['bn2.weight'].shape[0]
    if 'classifier.weight' in state_dict:
        # classifier.weight is (num_classes, num_features).
        return state_dict['classifier.weight'].shape[1]
    return None


def _create_model(model_name, num_classes):
    """Build an untrained timm model with the dropout settings used here."""
    return timm.create_model(
        model_name,
        pretrained=False,
        num_classes=num_classes,
        drop_rate=0.4,
        drop_path_rate=0.3,
    )


def _build_matching_model(state_dict, num_classes):
    """Return ``(model_name, model)`` for the architecture *state_dict* fits."""
    num_features = _detect_num_features(state_dict)
    if num_features is None:
        st.warning("Could not detect classifier or bn2 layer in checkpoint. Defaulting to efficientnet_b3")
        return _DEFAULT_MODEL_NAME, _create_model(_DEFAULT_MODEL_NAME, num_classes)
    candidates = _EFFICIENTNET_BY_FEATURES.get(num_features)
    if candidates is None:
        st.warning(f"Unrecognised head width {num_features} in checkpoint. Defaulting to efficientnet_b3")
        return _DEFAULT_MODEL_NAME, _create_model(_DEFAULT_MODEL_NAME, num_classes)
    models = [(name, _create_model(name, num_classes)) for name in candidates]
    # Variants with the same width (b0 vs b1) differ in block count, so their
    # parameter names tell them apart.
    for name, model in models:
        if set(model.state_dict()) == set(state_dict):
            return name, model
    return models[0]


@st.cache_resource
def load_model():
    """Load ``butterfly_classifier.pth`` into a matching timm EfficientNet.

    The architecture is inferred from the checkpoint's head width, and the
    output size comes from ``class_names``.

    Returns:
        The model in eval mode, or None (after showing an error in the UI)
        when the file, the class list, or the weights cannot be loaded.
    """
    MODEL_PATH = "butterfly_classifier.pth"
    if not os.path.exists(MODEL_PATH):
        st.error("Model file not found!")
        return None
    num_classes = len(class_names)
    if num_classes == 0:
        # timm treats num_classes=0 as "no classifier head", which would make
        # every prediction index into an empty label list.
        st.error("Cannot build the model: no class names were loaded.")
        return None
    try:
        # NOTE: torch.load unpickles the file; only ship checkpoints you trust.
        checkpoint = torch.load(MODEL_PATH, map_location='cpu')
        model_state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
        model_state_dict = _strip_data_parallel_prefix(model_state_dict)
        classifier_weight = model_state_dict.get('classifier.weight')
        if classifier_weight is not None and classifier_weight.shape[0] != num_classes:
            st.error(
                f"Model was trained on {classifier_weight.shape[0]} classes "
                f"but class_names.txt lists {num_classes}."
            )
            return None
        model_name, model = _build_matching_model(model_state_dict, num_classes)
        st.info(f"Detected model architecture: {model_name}")
        # strict=False tolerates extra checkpoint entries, but missing ones
        # would silently stay randomly initialised, so report them.
        incompatible = model.load_state_dict(model_state_dict, strict=False)
        if incompatible.missing_keys:
            st.warning(
                f"{len(incompatible.missing_keys)} model parameters were missing "
                "from the checkpoint and remain randomly initialised."
            )
        model.eval()
        return model
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None
model = load_model()
def predict_butterfly(image, threshold=0.5):
    """Classify a butterfly photo with the module-level ``model``.

    Args:
        image: a PIL image or a numpy array (converted with
            ``Image.fromarray``). Non-RGB images are converted to RGB.
        threshold: minimum top-class softmax probability to accept.

    Returns:
        ``(class_name, confidence)`` for an accepted prediction.
        ``(None, confidence)`` when the top probability is below *threshold*.
        ``(None, None)`` when *image* is None or anything fails; failures,
        including a model that was not loaded, are reported via ``st.error``.
    """
    try:
        if model is None:
            raise ValueError("Model is not loaded.")
        if image is None:
            return None, None

        pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        if pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')

        # Add a batch dimension of 1 for the single image.
        batch = inference_transform(image=np.array(pil_image))['image'].unsqueeze(0)
        with torch.no_grad():
            logits = model(batch)[0]
        probs = torch.nn.functional.softmax(logits, dim=0)
        top_prob, top_index = torch.max(probs, 0)
        score = top_prob.item()

        if score < threshold:
            return None, score
        return class_names[top_index.item()], score
    except Exception as e:
        st.error(f"Prediction error: {str(e)}")
        return None, None
# UI Code
def _render_prediction(image, caption):
    """Show *image* beside its predicted species, or tips if unrecognised.

    Shared by the camera and upload tabs, which used identical layouts.
    """
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption=caption, use_column_width=True)
    with col2:
        with st.spinner("Pildi analüüsimine... / Analyzing image..."):
            predicted_class, confidence = predict_butterfly(image)
        if predicted_class and confidence >= 0.50:
            st.success(f"**Liblikas / Butterfly: {predicted_class}**")
            # Not every species has an info entry or a description. Skip the
            # section instead of raising KeyError after the success banner.
            info = butterfly_info.get(predicted_class)
            description = info.get("description") if isinstance(info, dict) else None
            if description:
                st.markdown("**Liigi kirjeldus / About this species:**")
                st.write(description)
        else:
            st.warning("⚠️ Ma ei tea, mis liblikas see on / I don't know what butterfly this is")
            st.markdown("**Näpunäited paremate tulemuste saavutamiseks / Tips for better results:**")
            st.markdown("- Kasutage paremat valgustust / Use better lighting")
            st.markdown("- Veenduge, et liblikas oleks selgelt nähtav / Ensure the butterfly is clearly visible")
            st.markdown("- Vältige uduseid või tumedaid pilte / Avoid blurry or dark images")


st.title("🦋 Liblikamaja ID / Butterfly Identifier")
st.write("Tuvasta liblikaid oma kaamera abil või laadi üles pilt! / Identify butterflies using your camera or by uploading an image!")
tab1, tab2 = st.tabs(["📷 Live Camera / Kaamera", "📁 Upload Image / Laadi üles"])
with tab1:
    st.header("Kaamera jäädvustamine / Camera Capture")
    st.write("Tee pilt liblikast tuvastamiseks / Take a photo of a butterfly for identification.")
    camera_photo = st.camera_input("Pildista liblikat / Capture a butterfly")
    if camera_photo is not None:
        try:
            image = Image.open(camera_photo).convert("RGB")
            _render_prediction(image, "Jäädvustatud pilt / Captured Image")
        except Exception as e:
            st.error(f"Error processing image: {str(e)}")
with tab2:
    st.header("Laadi üles pilt / Upload Image")
    st.write("Laadige üles liblika selge foto tuvastamiseks / Upload a clear photo of a butterfly for identification.")
    uploaded_file = st.file_uploader("Vali pilt... / Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        try:
            image_bytes = uploaded_file.read()
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            _render_prediction(image, "Üleslaetud pilt / Uploaded Image")
        except Exception as e:
            st.error(f"Error processing image: {str(e)}")
# Footer: a short usage guide rendered below both tabs.
_HOW_TO_STEPS = (
    "1. **Kaamera jäädvustamine / Camera Capture**: Tehke foto oma seadme kaameraga / Take a photo using your device camera",
    "2. **Laadi pilt üles / Upload Image**: Vali oma seadmest liblika foto / Choose a butterfly photo from your device",
    "3. **Parimad tulemused / Best Results**: Kasutage selgeid ja hästi valgustatud fotosid, kus liblikas on selgelt nähtav / Use clear, well-lit photos with the butterfly clearly visible",
)
st.markdown("---")
st.markdown("### Kuidas kasutada / How to use:")
for _step in _HOW_TO_STEPS:
    st.markdown(_step)
|