Spaces:
Sleeping
Sleeping
# Enhanced Butterfly Identifier Streamlit App with Better Model Loading | |
import streamlit as st | |
from PIL import Image | |
import torch | |
import json | |
import os | |
import io | |
import numpy as np | |
import timm | |
import albumentations as A | |
from albumentations.pytorch import ToTensorV2 | |
import warnings | |
# Hide warnings emitted by the imported libraries from the app's output.
warnings.filterwarnings(action="ignore")

# Streamlit page settings: browser-tab title, tab icon and a wide layout.
_PAGE_CONFIG = {
    "page_title": "Butterfly Identifier / Liblikamaja ID",
    "page_icon": "🦋",
    "layout": "wide",
}
st.set_page_config(**_PAGE_CONFIG)
# Load class names
def load_class_names():
    """Read the ordered list of class labels from ``class_names.txt``.

    The file holds one label per line. Surrounding whitespace is stripped and
    blank lines are skipped: the number of labels becomes the classifier's
    ``num_classes`` in ``load_model``, so a stray empty line would otherwise
    add a phantom class and make the checkpoint's classifier shape mismatch.
    The file is decoded as UTF-8 (tolerating a BOM) so non-ASCII species names
    are read the same on every platform.

    Returns:
        list[str]: labels in file order, or an empty list (after showing a
        Streamlit error) when the file does not exist.
    """
    try:
        with open("class_names.txt", "r", encoding="utf-8-sig") as f:
            return [name for name in (line.strip() for line in f) if name]
    except FileNotFoundError:
        st.error("class_names.txt file not found!")
        return []


class_names = load_class_names()
# Load butterfly info
def load_butterfly_info():
    """Load optional per-species metadata from ``butterfly_info.json``.

    The UI looks entries up by predicted class name and reads their
    ``"description"`` field, so the file is expected to contain a JSON object.
    It is decoded as UTF-8, and a leading BOM is accepted.

    Returns:
        dict: the parsed mapping, or ``{}`` when the file is missing,
        unreadable, malformed, or not a JSON object. A missing file is normal
        (the info is optional); every other problem is reported with a
        Streamlit warning instead of being silently swallowed.
    """
    try:
        with open("butterfly_info.json", "r", encoding="utf-8-sig") as f:
            data = json.load(f)
    except FileNotFoundError:
        # Optional file: the app works without species descriptions.
        return {}
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        st.warning(f"Could not read butterfly_info.json: {e}")
        return {}
    if not isinstance(data, dict):
        st.warning("butterfly_info.json must contain a JSON object; ignoring it.")
        return {}
    return data


butterfly_info = load_butterfly_info()
# Deterministic preprocessing for single-pass prediction, intended to match
# the training pipeline: resize to 224x224, normalise with the ImageNet
# channel mean/std, then convert the HWC array to a CHW torch tensor.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
inference_transform = A.Compose(
    [
        A.Resize(height=224, width=224),
        A.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
        ToTensorV2(),
    ]
)
# Enhanced model loading function
# Checkpoint file names, searched in priority order.
_MODEL_FILES = (
    "butterfly_classifier.pth",
    "best_butterfly_model_v3.pth",
    "best_butterfly_model.pth",
)

# conv_stem output channels -> timm architectures with that stem width, per
# timm's model definitions (EfficientNet B0-B2 and EfficientNetV2-L: 32,
# B3: 40, B4/B5: 48, B6: 56, B7: 64, EfficientNetV2-S/M: 24). This only
# orders the search; every candidate is still verified against the weights.
_STEM_CHANNEL_HINTS = {
    24: ["tf_efficientnetv2_s", "tf_efficientnetv2_m"],
    32: ["efficientnet_b0", "efficientnet_b1", "efficientnet_b2", "tf_efficientnetv2_l"],
    40: ["efficientnet_b3"],
    48: ["efficientnet_b4", "efficientnet_b5"],
    56: ["efficientnet_b6"],
    64: ["efficientnet_b7"],
}

# Architectures tried after the channel-based hints.
_FALLBACK_ARCHITECTURES = (
    "tf_efficientnetv2_s",
    "efficientnet_b0",
    "efficientnet_b1",
    "efficientnet_b2",
    "efficientnet_b3",
    "tf_efficientnetv2_m",
    "efficientnet_b4",
)


def _find_model_file():
    """Return the first path in ``_MODEL_FILES`` that exists, else ``None``."""
    return next((path for path in _MODEL_FILES if os.path.exists(path)), None)


def _extract_state_dict(checkpoint):
    """Return the parameter mapping held by a loaded checkpoint.

    Accepts a pickled ``nn.Module``, a training checkpoint dict with a
    ``model_state_dict`` (or ``state_dict``) entry, or a bare state dict.
    When every key starts with ``module.`` (the prefix ``nn.DataParallel``
    adds), it is stripped so keys match a freshly created model.

    Raises:
        TypeError: for any other checkpoint type.
    """
    if isinstance(checkpoint, torch.nn.Module):
        state_dict = checkpoint.state_dict()
    elif isinstance(checkpoint, dict):
        state_dict = checkpoint
        for key in ("model_state_dict", "state_dict"):
            if key in checkpoint:
                state_dict = checkpoint[key]
                break
    else:
        raise TypeError(f"Unsupported checkpoint type: {type(checkpoint).__name__}")

    prefix = "module."
    if state_dict and all(str(k).startswith(prefix) for k in state_dict):
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
    return state_dict


def _candidate_architectures(state_dict):
    """Return timm architecture names to try, most likely first, without duplicates."""
    hinted = ["tf_efficientnetv2_s"]
    for key, tensor in state_dict.items():
        if key.endswith("conv_stem.weight"):
            # Output-channel count of the stem convolution.
            hinted = _STEM_CHANNEL_HINTS.get(tensor.shape[0], hinted)
            break
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return list(dict.fromkeys(hinted + list(_FALLBACK_ARCHITECTURES)))


def _build_model(arch, num_classes):
    """Create an uninitialised timm model with dropout disabled for inference."""
    return timm.create_model(
        arch,
        pretrained=False,
        num_classes=num_classes,
        drop_rate=0.0,
        drop_path_rate=0.0,
    )


def _load_first_matching(architectures, state_dict, num_classes):
    """Return ``(model, arch)`` for the first architecture that accepts the weights.

    Pass 1 requires an exact key and shape match (``strict=True``) for every
    candidate before any lenient attempt is made. Trying ``strict=False``
    right after each strict failure could accept a deeper relative (e.g.
    EfficientNet-B1 for a B0 checkpoint) and leave its extra blocks random.
    Pass 2 uses ``strict=False`` but accepts a candidate only when no model
    parameter is left missing; extra checkpoint entries are ignored.

    Returns:
        ``(model, arch)`` on success, ``(None, None)`` when nothing fits.
    """
    buildable = []
    for arch in architectures:
        st.info(f"Trying architecture: {arch}")
        try:
            model = _build_model(arch, num_classes)
        except Exception as e:
            st.warning(f"Failed to create model {arch}: {str(e)}")
            continue
        buildable.append(arch)
        try:
            model.load_state_dict(state_dict, strict=True)
        except Exception:
            continue  # retried leniently below if no architecture matches exactly
        st.success(f"✅ Successfully loaded model with architecture: {arch}")
        return model, arch

    for arch in buildable:
        # Fresh instance: a failed strict load may have copied some tensors.
        model = _build_model(arch, num_classes)
        try:
            # Shape mismatches still raise here even with strict=False.
            result = model.load_state_dict(state_dict, strict=False)
        except Exception as e:
            st.warning(f"Failed to load {arch}: {str(e)}")
            continue
        # BatchNorm num_batches_tracked counters are not used in eval mode.
        missing = [k for k in result.missing_keys if not k.endswith("num_batches_tracked")]
        if missing:
            st.warning(f"Failed to load {arch}: {len(missing)} model parameters missing from checkpoint")
            continue
        st.warning(f"⚠️ Loaded {arch} with some mismatched weights")
        return model, arch

    return None, None


@st.cache_resource
def load_model():
    """Locate a checkpoint, identify its timm architecture and build the model.

    Cached with ``st.cache_resource``: Streamlit re-executes the script on
    every widget interaction, and without caching the checkpoint would be
    re-read and the architecture search repeated each time.

    Returns:
        torch.nn.Module | None: the model in eval mode, or ``None`` after
        reporting the problem in the UI (no checkpoint, no class names,
        no matching architecture, or any loading error).
    """
    model_path = _find_model_file()
    if model_path is None:
        st.error("No model file found!")
        return None

    st.info(f"Loading model from: {model_path}")
    try:
        checkpoint = torch.load(model_path, map_location='cpu')
        model_state_dict = _extract_state_dict(checkpoint)

        num_classes = len(class_names)
        if num_classes == 0:
            # timm would build a headless model and predictions would fail later.
            st.error("❌ No class names available; cannot build the classifier.")
            return None

        model, successful_arch = _load_first_matching(
            _candidate_architectures(model_state_dict), model_state_dict, num_classes
        )
        if model is None:
            st.error("❌ Failed to load model with any architecture!")
            return None

        model.eval()

        total_params = sum(p.numel() for p in model.parameters())
        st.success("✅ Model loaded successfully!")
        st.info(f"📊 Model: {successful_arch}")
        st.info(f"🔢 Parameters: {total_params:,}")
        st.info(f"🎯 Classes: {num_classes}")
        return model
    except Exception as e:
        st.error(f"❌ Error loading model: {str(e)}")
        return None


# Load model
model = load_model()
def predict_butterfly(image, threshold=0.5):
    """Classify one image with a single deterministic forward pass.

    Args:
        image: a PIL image, or a NumPy array accepted by ``Image.fromarray``.
        threshold: minimum top-1 softmax probability for a prediction to count.

    Returns:
        ``(class_name, confidence)`` when the top probability reaches
        ``threshold``; ``(None, confidence)`` when it does not; and
        ``(None, None)`` when ``image`` is ``None`` or any error occurs
        (the error, including an unloaded model, is shown in the UI).
    """
    try:
        if model is None:
            raise ValueError("Model is not loaded.")
        if image is None:
            return None, None

        pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        if pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')

        # Preprocess and add a leading batch dimension of 1.
        batch = inference_transform(image=np.array(pil_image))['image'].unsqueeze(0)

        with torch.no_grad():
            logits = model(batch)
            class_probs = torch.nn.functional.softmax(logits[0], dim=0)
            top_prob, top_idx = torch.max(class_probs, 0)

        score = top_prob.item()
        if score < threshold:
            return None, score
        return class_names[top_idx.item()], score
    except Exception as e:
        st.error(f"Prediction error: {str(e)}")
        return None, None
# Normalisation statistics shared by every test-time-augmentation view.
_TTA_MEAN = [0.485, 0.456, 0.406]
_TTA_STD = [0.229, 0.224, 0.225]

# Test-time-augmentation views, built once at import rather than on every
# call. Each maps an RGB uint8 array to a normalised 224x224 CHW tensor.
# The rotation and colour views are random, so repeated predictions on the
# same image can differ slightly.
_TTA_TRANSFORMS = [
    # Plain resize, the same preprocessing as the single-pass prediction.
    A.Compose([
        A.Resize(224, 224),
        A.Normalize(mean=_TTA_MEAN, std=_TTA_STD),
        ToTensorV2(),
    ]),
    # Slight zoom-in via resize + centre crop.
    A.Compose([
        A.Resize(256, 256),
        A.CenterCrop(224, 224),
        A.Normalize(mean=_TTA_MEAN, std=_TTA_STD),
        ToTensorV2(),
    ]),
    # Mirror image.
    A.Compose([
        A.Resize(224, 224),
        A.HorizontalFlip(p=1.0),
        A.Normalize(mean=_TTA_MEAN, std=_TTA_STD),
        ToTensorV2(),
    ]),
    # Small random rotation (up to +/-10 degrees), then crop the borders.
    A.Compose([
        A.Resize(240, 240),
        A.Rotate(limit=10, p=1.0),
        A.CenterCrop(224, 224),
        A.Normalize(mean=_TTA_MEAN, std=_TTA_STD),
        ToTensorV2(),
    ]),
    # Brightness/contrast jitter only. Saturation and hue are pinned to 0
    # because albumentations' ColorJitter otherwise applies its non-zero
    # default saturation and hue jitter, and wing colour is what
    # distinguishes many species.
    A.Compose([
        A.Resize(224, 224),
        A.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.0, hue=0.0, p=1.0),
        A.Normalize(mean=_TTA_MEAN, std=_TTA_STD),
        ToTensorV2(),
    ]),
]


def predict_with_tta(image, threshold=0.5, num_tta=5):
    """Predict with test-time augmentation, averaging softmax over several views.

    Args:
        image: a PIL image, or a NumPy array accepted by ``Image.fromarray``.
        threshold: minimum averaged top-1 probability for a prediction to count.
        num_tta: how many of the predefined views to use, taken in order.
            Clamped to the range 1..len(_TTA_TRANSFORMS) (currently 5).

    Returns:
        ``(class_name, confidence)`` when the averaged top probability reaches
        ``threshold``; ``(None, confidence)`` when it does not; and
        ``(None, None)`` when ``image`` is ``None`` or any error occurs
        (the error, including an unloaded model, is shown in the UI).
    """
    try:
        if model is None:
            raise ValueError("Model is not loaded.")
        if image is None:
            return None, None

        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # albumentations works on NumPy arrays.
        image_np = np.array(image)

        # At least one view, so the average below is never over an empty set.
        n_views = max(1, min(int(num_tta), len(_TTA_TRANSFORMS)))
        batch = torch.stack(
            [transform(image=image_np)['image'] for transform in _TTA_TRANSFORMS[:n_views]]
        )

        # One batched forward pass instead of one per view. The model is put
        # in eval mode when loaded, so each view's output is independent of
        # the others in the batch.
        with torch.no_grad():
            probabilities = torch.nn.functional.softmax(model(batch), dim=1)

        avg_probabilities = probabilities.mean(dim=0)
        confidence, pred = torch.max(avg_probabilities, 0)

        if confidence.item() < threshold:
            return None, confidence.item()
        return class_names[pred.item()], confidence.item()
    except Exception as e:
        st.error(f"TTA Prediction error: {str(e)}")
        return None, None
# UI Code

# predict_with_tta defines five augmentation views; more rounds cannot add views.
_MAX_TTA_ROUNDS = 5

_TIPS = (
    "- Kasutage paremat valgustust / Use better lighting",
    "- Veenduge, et liblikas oleks selgelt nähtav / Ensure the butterfly is clearly visible",
    "- Vältige uduseid või tumedaid pilte / Avoid blurry or dark images",
    "- Proovige madalamat kindluse läviväärtust / Try a lower confidence threshold",
)


def _show_prediction(image, caption, threshold, use_tta, tta_rounds):
    """Render *image* next to its prediction; shared by the camera and upload tabs.

    Shows the species name, confidence and (if available) its description,
    or a "don't know" message with photography tips when the prediction
    falls below *threshold*.
    """
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption=caption, use_column_width=True)
    with col2:
        with st.spinner("Pildi analüüsimine... / Analyzing image..."):
            if use_tta:
                predicted_class, confidence = predict_with_tta(image, threshold, tta_rounds)
            else:
                predicted_class, confidence = predict_butterfly(image, threshold)

        if predicted_class and confidence >= threshold:
            st.success(f"**Liblikas / Butterfly: {predicted_class}**")
            st.info(f"Confidence: {confidence:.2%}")
            # Entries may be incomplete; avoid a KeyError on a missing description.
            entry = butterfly_info.get(predicted_class)
            description = entry.get("description") if isinstance(entry, dict) else None
            if description:
                st.markdown("**Liigi kirjeldus / About this species:**")
                st.write(description)
            else:
                st.info("No additional information available for this species.")
        else:
            # `is not None` so that a 0.00% confidence is still reported.
            confidence_text = f" (Confidence: {confidence:.2%})" if confidence is not None else ""
            st.warning(f"⚠️ Ma ei tea, mis liblikas see on / I don't know what butterfly this is{confidence_text}")
            st.markdown("**Näpunäited paremate tulemuste saavutamiseks / Tips for better results:**")
            for tip in _TIPS:
                st.markdown(tip)


st.title("🦋 Liblikamaja ID / Butterfly Identifier")
st.write("Tuvasta liblikaid oma kaamera abil või laadi üles pilt! / Identify butterflies using your camera or by uploading an image!")

# Add model status indicator
if model is not None:
    st.success("✅ Model loaded and ready!")
else:
    st.error("❌ Model not loaded. Please check your model file.")
    st.stop()

# Add advanced options
with st.expander("🔧 Advanced Options / Täpsemad seaded"):
    confidence_threshold = st.slider(
        "Confidence Threshold / Kindluse lävi",
        min_value=0.1,
        max_value=1.0,
        value=0.5,
        step=0.05,
        help="Higher values = more conservative predictions"
    )
    use_tta = st.checkbox(
        "Use Test Time Augmentation (TTA) / Kasuta TTA",
        value=False,
        help="Slower but potentially more accurate predictions"
    )
    tta_rounds = _MAX_TTA_ROUNDS
    if use_tta:
        tta_rounds = st.slider(
            "TTA Rounds / TTA ringid",
            min_value=3,
            max_value=_MAX_TTA_ROUNDS,
            value=_MAX_TTA_ROUNDS,
            help="More rounds = slower but potentially more accurate"
        )

tab1, tab2 = st.tabs(["📷 Live Camera / Kaamera", "📁 Upload Image / Laadi üles"])

with tab1:
    st.header("Kaamera jäädvustamine / Camera Capture")
    st.write("Tee pilt liblikast tuvastamiseks / Take a photo of a butterfly for identification.")
    camera_photo = st.camera_input("Pildista liblikat / Capture a butterfly")
    if camera_photo is not None:
        try:
            # getvalue() returns the whole buffer regardless of read position.
            image = Image.open(io.BytesIO(camera_photo.getvalue())).convert("RGB")
            _show_prediction(image, "Jäädvustatud pilt / Captured Image",
                             confidence_threshold, use_tta, tta_rounds)
        except Exception as e:
            st.error(f"Error processing image: {str(e)}")

with tab2:
    st.header("Laadi üles pilt / Upload Image")
    st.write("Laadige üles liblika selge foto tuvastamiseks / Upload a clear photo of a butterfly for identification.")
    uploaded_file = st.file_uploader("Vali pilt... / Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        try:
            image_bytes = uploaded_file.getvalue()
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            _show_prediction(image, "Üleslaetud pilt / Uploaded Image",
                             confidence_threshold, use_tta, tta_rounds)
        except Exception as e:
            st.error(f"Error processing image: {str(e)}")

# Footer
st.markdown("---")
st.markdown("### Kuidas kasutada / How to use:")
st.markdown("1. **Kaamera jäädvustamine / Camera Capture**: Tehke foto oma seadme kaameraga / Take a photo using your device camera")
st.markdown("2. **Laadi pilt üles / Upload Image**: Vali oma seadmest liblika foto / Choose a butterfly photo from your device")
st.markdown("3. **Parimad tulemused / Best Results**: Kasutage selgeid ja hästi valgustatud fotosid, kus liblikas on selgelt nähtav / Use clear, well-lit photos with the butterfly clearly visible")
st.markdown("4. **Täpsemad seaded / Advanced Options**: Kohandage kindluse lävi ja kasutage TTA paremate tulemuste saamiseks / Adjust confidence threshold and use TTA for better results")

# Debug info
if st.checkbox("Show Debug Info"):
    st.write("**Class Names:**", class_names)
    st.write("**Number of Classes:**", len(class_names))
    st.write("**Model Status:**", "Loaded" if model is not None else "Not Loaded")
    if butterfly_info:
        st.write("**Species Info Available:**", len(butterfly_info))