# libkamaja_id / streamlit_app.py
# leynessa's picture
# Update streamlit_app.py
# 2edff24 verified
# Enhanced Butterfly Identifier Streamlit App with Better Model Loading
import streamlit as st
from PIL import Image
import torch
import json
import os
import io
import numpy as np
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import warnings
# Suppress all Python warnings process-wide (same filter as filterwarnings("ignore")).
warnings.simplefilter("ignore")
# Page-level settings for the Streamlit app.
st.set_page_config(
    layout="wide",
    page_icon="🦋",
    page_title="Butterfly Identifier / Liblikamaja ID",
)
# Load class names
@st.cache_data
def load_class_names():
    """Return the species labels listed in ``class_names.txt``, one per line.

    The list order is the model's output order: predictions are mapped back
    to a label with ``class_names[index]``. Surrounding whitespace is stripped
    and blank lines (e.g. an extra empty line at the end of the file) are
    skipped, because an empty label would add a phantom class and make
    ``len(class_names)`` disagree with the trained classifier head.

    Returns an empty list (after showing an error) if the file is missing.
    """
    try:
        # utf-8-sig decodes plain UTF-8 unchanged (species names may contain
        # non-ASCII letters) and drops a leading BOM instead of gluing it onto
        # the first label.
        with open("class_names.txt", "r", encoding="utf-8-sig") as f:
            return [name for name in (line.strip() for line in f) if name]
    except FileNotFoundError:
        st.error("class_names.txt file not found!")
        return []


class_names = load_class_names()
# Load butterfly info
@st.cache_data
def load_butterfly_info():
    """Return the species-info mapping from ``butterfly_info.json``.

    The file is optional: when it is missing an empty dict is returned
    silently. When it exists but cannot be read or parsed, or does not hold a
    JSON object, a warning is shown and an empty dict is returned so the app
    keeps working without species descriptions.
    """
    try:
        with open("butterfly_info.json", "r", encoding="utf-8-sig") as f:
            info = json.load(f)
    except FileNotFoundError:
        return {}
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        st.warning(f"Could not read butterfly_info.json: {e}")
        return {}
    if not isinstance(info, dict):
        # The UI looks species up by name, which needs a JSON object.
        st.warning("butterfly_info.json must contain a JSON object; ignoring it.")
        return {}
    return info


butterfly_info = load_butterfly_info()
# Define transform matching training pipeline.
# Per-channel mean/std for Normalize (the standard ImageNet statistics); these
# must stay in sync with the values used when the model was trained.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

inference_transform = A.Compose(
    [
        A.Resize(height=224, width=224),  # fixed 224x224 model input
        A.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
        ToTensorV2(),  # HWC numpy array -> CHW torch tensor
    ]
)
# Enhanced model loading function

# Checkpoint file names, tried in this order; the first one that exists wins.
_MODEL_FILES = (
    "butterfly_classifier.pth",
    "best_butterfly_model_v3.pth",
    "best_butterfly_model.pth",
)

# conv_stem output channels -> timm architectures whose default config has
# that stem width. Only used to decide which architectures to try first.
_STEM_CHANNELS_TO_ARCHS = {
    24: ["tf_efficientnetv2_s", "tf_efficientnetv2_m"],
    32: ["efficientnet_b0", "efficientnet_b1", "efficientnet_b2", "tf_efficientnetv2_l"],
    40: ["efficientnet_b3"],
    48: ["efficientnet_b4", "efficientnet_b5"],
    56: ["efficientnet_b6"],
    64: ["efficientnet_b7"],
}

# Architectures tried after the detected ones.
_FALLBACK_ARCHS = [
    "tf_efficientnetv2_s",
    "efficientnet_b0",
    "efficientnet_b1",
    "efficientnet_b2",
    "efficientnet_b3",
    "tf_efficientnetv2_m",
    "efficientnet_b4",
]

# A non-strict load is accepted only if at least this fraction of the model's
# state-dict entries were filled from the checkpoint. Without this check a
# checkpoint whose key names do not match at all would "load" without error
# and leave the model with random weights.
_MIN_LOADED_FRACTION = 0.9


def _find_model_file():
    """Return the first path in ``_MODEL_FILES`` that exists, or ``None``."""
    return next((path for path in _MODEL_FILES if os.path.exists(path)), None)


def _extract_state_dict(checkpoint):
    """Return the parameter mapping stored in a loaded checkpoint object.

    Accepts a bare state dict, a training checkpoint dict holding a
    ``model_state_dict`` (or ``state_dict``) entry, or a pickled
    ``nn.Module``. If every key carries a ``module.`` prefix (typical of
    checkpoints saved from a DataParallel/DDP wrapper) it is stripped so the
    keys match a bare timm model.
    """
    if isinstance(checkpoint, torch.nn.Module):
        return checkpoint.state_dict()
    state_dict = checkpoint
    if isinstance(checkpoint, dict):
        for key in ("model_state_dict", "state_dict"):
            if isinstance(checkpoint.get(key), dict):
                state_dict = checkpoint[key]
                break
    prefix = "module."
    if isinstance(state_dict, dict) and state_dict and all(
        str(k).startswith(prefix) for k in state_dict
    ):
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
    return state_dict


def _candidate_architectures(state_dict):
    """Return timm architecture names to try, most likely first, without duplicates."""
    stem_channels = next(
        (t.shape[0] for k, t in state_dict.items() if k.endswith("conv_stem.weight")),
        None,
    )
    likely = _STEM_CHANNELS_TO_ARCHS.get(stem_channels, ["tf_efficientnetv2_s"])
    # dict.fromkeys drops duplicates while keeping first-seen order.
    return list(dict.fromkeys(likely + _FALLBACK_ARCHS))


def _load_first_matching(architectures, state_dict, num_classes):
    """Build each architecture in turn and load ``state_dict`` into it.

    Returns ``(model, arch)`` for the first architecture whose weights load
    strictly, or non-strictly with at least ``_MIN_LOADED_FRACTION`` of the
    model's entries filled. Returns ``(None, None)`` if none qualifies.
    """
    for arch in architectures:
        st.info(f"Trying architecture: {arch}")
        try:
            model = timm.create_model(
                arch,
                pretrained=False,
                num_classes=num_classes,
                drop_rate=0.0,  # Set to 0 for inference
                drop_path_rate=0.0,  # Set to 0 for inference
            )
        except Exception as e:
            st.warning(f"Failed to create model {arch}: {str(e)}")
            continue

        try:
            model.load_state_dict(state_dict, strict=True)
        except Exception:
            pass  # fall through to the tolerant, verified load below
        else:
            st.success(f"✅ Successfully loaded model with architecture: {arch}")
            return model, arch

        try:
            # strict=False tolerates missing/unexpected keys but still raises
            # on tensor shape mismatches.
            result = model.load_state_dict(state_dict, strict=False)
        except Exception as e2:
            st.warning(f"Failed to load {arch}: {str(e2)}")
            continue

        total = len(model.state_dict())
        loaded = total - len(result.missing_keys)
        if total and loaded / total >= _MIN_LOADED_FRACTION:
            st.warning(
                f"⚠️ Loaded {arch} with some mismatched weights "
                f"({loaded}/{total} entries loaded)"
            )
            return model, arch
        st.warning(f"Failed to load {arch}: only {loaded}/{total} weights matched the checkpoint")
    return None, None


@st.cache_resource
def load_model():
    """Locate a checkpoint, match it to a timm architecture and return the model.

    Progress and diagnostics are reported with Streamlit status messages.
    The returned model is in eval mode. Returns ``None`` (after showing an
    error) when no checkpoint file exists, no class names were loaded, the
    checkpoint cannot be read, or no candidate architecture accepts the
    weights.
    """
    model_path = _find_model_file()
    if model_path is None:
        st.error("No model file found!")
        return None
    st.info(f"Loading model from: {model_path}")

    num_classes = len(class_names)
    if num_classes == 0:
        # timm treats num_classes=0 as "no classifier head", which would turn
        # the model into a feature extractor instead of failing.
        st.error("❌ No class names available; cannot build the classifier head.")
        return None

    try:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location="cpu")
        state_dict = _extract_state_dict(checkpoint)
        model, arch = _load_first_matching(
            _candidate_architectures(state_dict), state_dict, num_classes
        )
    except Exception as e:
        st.error(f"❌ Error loading model: {str(e)}")
        return None

    # Only a model whose weights actually loaded is returned; a model that was
    # merely created (random weights) is rejected here.
    if model is None:
        st.error("❌ Failed to load model with any architecture!")
        return None

    model.eval()
    total_params = sum(p.numel() for p in model.parameters())
    st.success("✅ Model loaded successfully!")
    st.info(f"📊 Model: {arch}")
    st.info(f"🔢 Parameters: {total_params:,}")
    st.info(f"🎯 Classes: {num_classes}")
    return model


# Load model
model = load_model()
def predict_butterfly(image, threshold=0.5):
    """Classify one image with the plain (non-TTA) inference transform.

    Args:
        image: PIL image or numpy array; converted to RGB before transforming.
        threshold: minimum top-class softmax probability needed to return a label.

    Returns:
        ``(label, confidence)`` when the top probability reaches ``threshold``;
        ``(None, confidence)`` when it does not; ``(None, None)`` when ``image``
        is ``None`` or any error occurs (the error is shown with ``st.error``).
    """
    try:
        if model is None:
            raise ValueError("Model is not loaded.")
        if image is None:
            return None, None

        pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        if pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')

        # Albumentations works on numpy arrays; add a batch dimension of 1.
        batch = inference_transform(image=np.array(pil_image))['image'].unsqueeze(0)
        with torch.no_grad():
            logits = model(batch)
            probs = torch.softmax(logits[0], dim=0)
        top_prob, top_idx = probs.max(dim=0)

        score = top_prob.item()
        if score < threshold:
            return None, score
        return class_names[top_idx.item()], score
    except Exception as e:
        st.error(f"Prediction error: {str(e)}")
        return None, None
def _tta_pipeline(*steps):
    """Return ``steps`` followed by the shared normalisation and tensor conversion."""
    return A.Compose([
        *steps,
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])


# Deterministic views: the same output on every call.
_TTA_FIXED_VIEWS = [
    _tta_pipeline(A.Resize(224, 224)),
    _tta_pipeline(A.Resize(256, 256), A.CenterCrop(224, 224)),
    _tta_pipeline(A.Resize(224, 224), A.HorizontalFlip(p=1.0)),
]
# Stochastic views: every call samples a fresh rotation / colour jitter, so
# repeating them yields new augmentations rather than duplicate votes.
_TTA_RANDOM_VIEWS = [
    _tta_pipeline(A.Resize(240, 240), A.Rotate(limit=10, p=1.0), A.CenterCrop(224, 224)),
    _tta_pipeline(A.Resize(224, 224), A.ColorJitter(brightness=0.1, contrast=0.1, p=1.0)),
]


def _select_tta_views(num_tta):
    """Return ``num_tta`` (at least one) pipelines for a TTA run.

    The first five are the fixed views followed by the random views; any
    rounds beyond five cycle through the random views.
    """
    views = _TTA_FIXED_VIEWS + _TTA_RANDOM_VIEWS
    count = max(1, int(num_tta))
    if count <= len(views):
        return views[:count]
    extra = count - len(views)
    return views + [_TTA_RANDOM_VIEWS[i % len(_TTA_RANDOM_VIEWS)] for i in range(extra)]


def predict_with_tta(image, threshold=0.5, num_tta=5):
    """Predict with Test Time Augmentation: average softmax over augmented views.

    Args:
        image: PIL image or numpy array; converted to RGB before transforming.
        threshold: minimum averaged top-class probability needed to return a label.
        num_tta: number of augmented views. Values above 5 add extra random
            rotation / colour-jitter views; values below 1 are treated as 1.

    Returns:
        ``(label, confidence)`` when the averaged top probability reaches
        ``threshold``; ``(None, confidence)`` when it does not; ``(None, None)``
        when ``image`` is ``None`` or any error occurs (shown with ``st.error``).
    """
    try:
        if model is None:
            raise ValueError("Model is not loaded.")
        if image is None:
            return None, None

        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_np = np.array(image)  # albumentations expects numpy input

        # All views go through the model as one batch: a single forward pass
        # instead of num_tta. The model is in eval mode, so each sample is
        # processed independently of the others in the batch.
        batch = torch.stack(
            [view(image=image_np)['image'] for view in _select_tta_views(num_tta)]
        )
        with torch.no_grad():
            probabilities = torch.nn.functional.softmax(model(batch), dim=1)
        avg_probabilities = probabilities.mean(dim=0)
        confidence, pred = torch.max(avg_probabilities, 0)

        if confidence.item() < threshold:
            return None, confidence.item()
        return class_names[pred.item()], confidence.item()
    except Exception as e:
        st.error(f"TTA Prediction error: {str(e)}")
        return None, None
# UI Code

# Hints shown when no confident prediction could be made.
_TIPS = [
    "- Kasutage paremat valgustust / Use better lighting",
    "- Veenduge, et liblikas oleks selgelt nähtav / Ensure the butterfly is clearly visible",
    "- Vältige uduseid või tumedaid pilte / Avoid blurry or dark images",
    "- Proovige madalamat kindluse läviväärtust / Try a lower confidence threshold",
]


def _classify(image, threshold, use_tta, tta_rounds):
    """Return ``(label, confidence)`` from the predictor selected in the options."""
    if use_tta:
        return predict_with_tta(image, threshold, tta_rounds)
    return predict_butterfly(image, threshold)


def _show_species_info(predicted_class):
    """Show the ``description`` stored for ``predicted_class`` in butterfly_info, if any."""
    entry = butterfly_info.get(predicted_class) if isinstance(butterfly_info, dict) else None
    description = entry.get("description") if isinstance(entry, dict) else None
    if description:
        st.markdown("**Liigi kirjeldus / About this species:**")
        st.write(description)
    else:
        st.info("No additional information available for this species.")


def _show_result(predicted_class, confidence, threshold):
    """Render a confident prediction, or the "unknown butterfly" message with tips."""
    if predicted_class and confidence >= threshold:
        st.success(f"**Liblikas / Butterfly: {predicted_class}**")
        st.info(f"Confidence: {confidence:.2%}")
        _show_species_info(predicted_class)
        return
    # `is not None` so that a confidence of exactly 0.0 is still reported.
    confidence_text = f" (Confidence: {confidence:.2%})" if confidence is not None else ""
    st.warning(f"⚠️ Ma ei tea, mis liblikas see on / I don't know what butterfly this is{confidence_text}")
    st.markdown("**Näpunäited paremate tulemuste saavutamiseks / Tips for better results:**")
    for tip in _TIPS:
        st.markdown(tip)


def _analyze_image(image_bytes, caption, threshold, use_tta, tta_rounds):
    """Decode, display and classify one image side by side; errors are shown in the UI."""
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        col1, col2 = st.columns(2)
        with col1:
            st.image(image, caption=caption, use_column_width=True)
        with col2:
            with st.spinner("Pildi analüüsimine... / Analyzing image..."):
                predicted_class, confidence = _classify(image, threshold, use_tta, tta_rounds)
            _show_result(predicted_class, confidence, threshold)
    except Exception as e:
        st.error(f"Error processing image: {str(e)}")


st.title("🦋 Liblikamaja ID / Butterfly Identifier")
st.write("Tuvasta liblikaid oma kaamera abil või laadi üles pilt! / Identify butterflies using your camera or by uploading an image!")

# Add model status indicator; nothing below is usable without a model.
if model is not None:
    st.success("✅ Model loaded and ready!")
else:
    st.error("❌ Model not loaded. Please check your model file.")
    st.stop()

# Add advanced options
tta_rounds = None  # only set (and only used) when TTA is enabled
with st.expander("🔧 Advanced Options / Täpsemad seaded"):
    confidence_threshold = st.slider(
        "Confidence Threshold / Kindluse lävi",
        min_value=0.1,
        max_value=1.0,
        value=0.5,
        step=0.05,
        help="Higher values = more conservative predictions",
    )
    use_tta = st.checkbox(
        "Use Test Time Augmentation (TTA) / Kasuta TTA",
        value=False,
        help="Slower but potentially more accurate predictions",
    )
    if use_tta:
        tta_rounds = st.slider(
            "TTA Rounds / TTA ringid",
            min_value=3,
            max_value=8,
            value=5,
            help="More rounds = slower but potentially more accurate",
        )

tab1, tab2 = st.tabs(["📷 Live Camera / Kaamera", "📁 Upload Image / Laadi üles"])

with tab1:
    st.header("Kaamera jäädvustamine / Camera Capture")
    st.write("Tee pilt liblikast tuvastamiseks / Take a photo of a butterfly for identification.")
    camera_photo = st.camera_input("Pildista liblikat / Capture a butterfly")
    if camera_photo is not None:
        # getvalue() returns the full buffer regardless of the stream position.
        _analyze_image(
            camera_photo.getvalue(),
            "Jäädvustatud pilt / Captured Image",
            confidence_threshold,
            use_tta,
            tta_rounds,
        )

with tab2:
    st.header("Laadi üles pilt / Upload Image")
    st.write("Laadige üles liblika selge foto tuvastamiseks / Upload a clear photo of a butterfly for identification.")
    uploaded_file = st.file_uploader("Vali pilt... / Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        _analyze_image(
            uploaded_file.getvalue(),
            "Üleslaetud pilt / Uploaded Image",
            confidence_threshold,
            use_tta,
            tta_rounds,
        )

# Footer
st.markdown("---")
st.markdown("### Kuidas kasutada / How to use:")
st.markdown("1. **Kaamera jäädvustamine / Camera Capture**: Tehke foto oma seadme kaameraga / Take a photo using your device camera")
st.markdown("2. **Laadi pilt üles / Upload Image**: Vali oma seadmest liblika foto / Choose a butterfly photo from your device")
st.markdown("3. **Parimad tulemused / Best Results**: Kasutage selgeid ja hästi valgustatud fotosid, kus liblikas on selgelt nähtav / Use clear, well-lit photos with the butterfly clearly visible")
st.markdown("4. **Täpsemad seaded / Advanced Options**: Kohandage kindluse lävi ja kasutage TTA paremate tulemuste saamiseks / Adjust confidence threshold and use TTA for better results")

# Debug info
if st.checkbox("Show Debug Info"):
    st.write("**Class Names:**", class_names)
    st.write("**Number of Classes:**", len(class_names))
    st.write("**Model Status:**", "Loaded" if model is not None else "Not Loaded")
    if butterfly_info:
        st.write("**Species Info Available:**", len(butterfly_info))