# Enhanced Butterfly Identifier Streamlit App with Better Model Loading
import streamlit as st
from PIL import Image
import torch
import json
import os
import io
import numpy as np
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import warnings

warnings.filterwarnings('ignore')

# Configure Streamlit (must be the first Streamlit command in the script)
st.set_page_config(
    page_title="Butterfly Identifier / Liblikamaja ID",
    page_icon="🦋",
    layout="wide"
)


# Load class names: one per line, in the same order as the model's output indices
@st.cache_data
def load_class_names():
    try:
        with open("class_names.txt", "r", encoding="utf-8") as f:
            # Skip blank lines so a trailing newline does not add an empty class
            return [line.strip() for line in f if line.strip()]
    except FileNotFoundError:
        st.error("class_names.txt file not found!")
        return []


class_names = load_class_names()


# Load optional species descriptions, keyed by class name
@st.cache_data
def load_butterfly_info():
    try:
        with open("butterfly_info.json", "r", encoding="utf-8") as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        return {}


butterfly_info = load_butterfly_info()

# ImageNet normalization statistics, shared by every transform below
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Define transform matching training pipeline
inference_transform = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2()
])


# Enhanced model loading function
@st.cache_resource
def load_model():
    """Load the checkpoint and find the timm architecture whose weights match it."""
    # Candidate checkpoint files, in order of preference
    model_files = [
        "butterfly_classifier.pth",
        "best_butterfly_model_v3.pth",
        "best_butterfly_model.pth"
    ]
    model_path = next((f for f in model_files if os.path.exists(f)), None)
    if model_path is None:
        st.error("No model file found!")
        return None

    st.info(f"Loading model from: {model_path}")

    try:
        # On PyTorch >= 2.6, torch.load defaults to weights_only=True; a checkpoint that
        # pickles non-tensor objects may need weights_only=False (only for trusted files).
        checkpoint = torch.load(model_path, map_location='cpu')

        # Accept either a raw state dict or a training checkpoint that wraps one
        if 'model_state_dict' in checkpoint:
            model_state_dict = checkpoint['model_state_dict']
        else:
            model_state_dict = checkpoint

        # Strip the "module." prefix that nn.DataParallel adds to every key
        model_state_dict = {
            (k[len('module.'):] if k.startswith('module.') else k): v
            for k, v in model_state_dict.items()
        }

        num_classes = len(class_names)
        if num_classes == 0:
            st.error("❌ No class names loaded; cannot build the classifier head.")
            return None

        def detect_architecture_by_channels(state_dict):
            """Guess candidate architectures from the conv_stem output channels.

            Stem widths in timm's default configs: tf_efficientnetv2_s/m use 24,
            efficientnet_b0-b2 and tf_efficientnetv2_l use 32, b3 uses 40,
            b4/b5 use 48, b6 uses 56 and b7 uses 64.
            """
            channel_map = {
                24: ['tf_efficientnetv2_s', 'tf_efficientnetv2_m'],
                32: ['efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'tf_efficientnetv2_l'],
                40: ['efficientnet_b3'],
                48: ['efficientnet_b4', 'efficientnet_b5'],
                56: ['efficientnet_b6'],
                64: ['efficientnet_b7'],
            }
            for key, tensor in state_dict.items():
                if key.endswith('conv_stem.weight'):
                    channels = tensor.shape[0]  # output channels of the stem conv
                    return channel_map.get(channels, ['tf_efficientnetv2_s'])
            return ['tf_efficientnetv2_s']

        # Try the channel-based guesses first, then a fixed fallback list
        likely_architectures = detect_architecture_by_channels(model_state_dict)
        fallback_architectures = [
            'tf_efficientnetv2_s', 'efficientnet_b0', 'efficientnet_b1',
            'efficientnet_b2', 'efficientnet_b3', 'tf_efficientnetv2_m', 'efficientnet_b4'
        ]
        # Remove duplicates while preserving order
        architectures_to_try = list(dict.fromkeys(likely_architectures + fallback_architectures))

        model = None
        successful_arch = None

        for arch in architectures_to_try:
            try:
                st.info(f"Trying architecture: {arch}")
                candidate = timm.create_model(
                    arch,
                    pretrained=False,
                    num_classes=num_classes,
                    drop_rate=0.0,       # dropout is disabled for inference
                    drop_path_rate=0.0   # stochastic depth is disabled for inference
                )
            except Exception as e:
                st.warning(f"Failed to create model {arch}: {e}")
                continue

            try:
                candidate.load_state_dict(model_state_dict, strict=True)
                st.success(f"✅ Successfully loaded model with architecture: {arch}")
            except RuntimeError:
                # strict=False still raises on shape mismatches; it only tolerates
                # missing or unexpected keys, so inspect what was actually loaded.
                try:
                    result = candidate.load_state_dict(model_state_dict, strict=False)
                except RuntimeError as e2:
                    st.warning(f"Failed to load {arch}: {e2}")
                    continue
                if result.missing_keys:
                    # Missing weights would stay randomly initialized, so reject this architecture
                    st.warning(f"Failed to load {arch}: {len(result.missing_keys)} weights missing from checkpoint")
                    continue
                st.warning(f"⚠️ Loaded {arch}, ignoring {len(result.unexpected_keys)} unexpected checkpoint keys")

            model = candidate
            successful_arch = arch
            break

        # model stays None unless some architecture actually loaded the weights
        if model is None:
            st.error("❌ Failed to load model with any architecture!")
            return None

        # Set model to evaluation mode (fixes BatchNorm statistics, disables dropout)
        model.eval()

        # Display model info
        total_params = sum(p.numel() for p in model.parameters())
        st.success("✅ Model loaded successfully!")
        st.info(f"📊 Model: {successful_arch}")
        st.info(f"🔢 Parameters: {total_params:,}")
        st.info(f"🎯 Classes: {num_classes}")
        return model

    except Exception as e:
        st.error(f"❌ Error loading model: {e}")
        return None


# Load model
model = load_model()


def prepare_image(image):
    """Return an RGB numpy array from a PIL image or numpy array."""
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    return np.array(image)


def predict_butterfly(image, threshold=0.5):
    """Predict butterfly species from a single resized view of the image.

    Returns (class_name, confidence). class_name is None when confidence is
    below threshold; both are None on error.
    """
    try:
        if model is None:
            raise ValueError("Model is not loaded.")
        if image is None:
            return None, None

        transformed = inference_transform(image=prepare_image(image))
        input_tensor = transformed['image'].unsqueeze(0)  # add batch dimension

        with torch.no_grad():
            output = model(input_tensor)
            probabilities = torch.nn.functional.softmax(output[0], dim=0)
            confidence, pred = torch.max(probabilities, 0)

        if confidence.item() < threshold:
            return None, confidence.item()
        return class_names[pred.item()], confidence.item()

    except Exception as e:
        st.error(f"Prediction error: {e}")
        return None, None


# Test-time augmentation views, built once; the rotation view is random on each call
TTA_TRANSFORMS = [
    inference_transform,  # plain resize, identical to the single-view pipeline
    A.Compose([
        A.Resize(256, 256),
        A.CenterCrop(224, 224),
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2()
    ]),
    A.Compose([
        A.Resize(224, 224),
        A.HorizontalFlip(p=1.0),
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2()
    ]),
    A.Compose([
        A.Resize(240, 240),
        A.Rotate(limit=10, p=1.0),
        A.CenterCrop(224, 224),
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2()
    ]),
    A.Compose([
        A.Resize(224, 224),
        # Saturation and hue are pinned to 0 so this view does not shift wing colors
        # (ColorJitter jitters them by default)
        A.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.0, hue=0.0, p=1.0),
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2()
    ]),
]


def predict_with_tta(image, threshold=0.5, num_tta=5):
    """Predict with test-time augmentation: average softmax probabilities over up to
    len(TTA_TRANSFORMS) augmented views of the image."""
    try:
        if model is None:
            raise ValueError("Model is not loaded.")
        if image is None:
            return None, None

        image_np = prepare_image(image)
        # Stack the views into one batch; in eval mode each view is scored independently
        batch = torch.stack([t(image=image_np)['image'] for t in TTA_TRANSFORMS[:num_tta]])

        with torch.no_grad():
            probabilities = torch.nn.functional.softmax(model(batch), dim=1)

        # Average class probabilities across views, then take the top class
        avg_probabilities = probabilities.mean(dim=0)
        confidence, pred = torch.max(avg_probabilities, 0)

        if confidence.item() < threshold:
            return None, confidence.item()
        return class_names[pred.item()], confidence.item()

    except Exception as e:
        st.error(f"TTA Prediction error: {e}")
        return None, None


# UI Code
st.title("🦋 Liblikamaja ID / Butterfly Identifier")
st.write("Tuvasta liblikaid oma kaamera abil või laadi üles pilt! / Identify butterflies using your camera or by uploading an image!")

# Model status indicator; stop rendering the rest of the app if loading failed
if model is not None:
    st.success("✅ Model loaded and ready!")
else:
    st.error("❌ Model not loaded. Please check your model file.")
    st.stop()

# Advanced options
with st.expander("🔧 Advanced Options / Täpsemad seaded"):
    confidence_threshold = st.slider(
        "Confidence Threshold / Kindluse lävi",
        min_value=0.1,
        max_value=1.0,
        value=0.5,
        step=0.05,
        help="Predictions below this confidence are reported as unknown; higher values are more conservative"
    )
    use_tta = st.checkbox(
        "Use Test Time Augmentation (TTA) / Kasuta TTA",
        value=False,
        help="Slower but potentially more accurate predictions"
    )
    tta_rounds = len(TTA_TRANSFORMS)
    if use_tta:
        # Only len(TTA_TRANSFORMS) distinct views exist, so more rounds would have no effect
        tta_rounds = st.slider(
            "TTA Rounds / TTA ringid",
            min_value=3,
            max_value=len(TTA_TRANSFORMS),
            value=len(TTA_TRANSFORMS),
            help="More rounds = slower but potentially more accurate"
        )


def show_prediction(image, caption):
    """Show the image next to its prediction (shared by the camera and upload tabs)."""
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption=caption, use_column_width=True)
    with col2:
        with st.spinner("Pildi analüüsimine... / Analyzing image..."):
            if use_tta:
                predicted_class, confidence = predict_with_tta(image, confidence_threshold, tta_rounds)
            else:
                predicted_class, confidence = predict_butterfly(image, confidence_threshold)

        # The predict functions already return None for below-threshold results
        if predicted_class is not None:
            st.success(f"**Liblikas / Butterfly: {predicted_class}**")
            st.info(f"Confidence: {confidence:.2%}")
            description = butterfly_info.get(predicted_class, {}).get("description")
            if description:
                st.markdown("**Liigi kirjeldus / About this species:**")
                st.write(description)
            else:
                st.info("No additional information available for this species.")
        else:
            confidence_text = f" (Confidence: {confidence:.2%})" if confidence is not None else ""
            st.warning(f"⚠️ Ma ei tea, mis liblikas see on / I don't know what butterfly this is{confidence_text}")
            st.markdown("**Näpunäited paremate tulemuste saavutamiseks / Tips for better results:**")
            st.markdown("- Kasutage paremat valgustust / Use better lighting")
            st.markdown("- Veenduge, et liblikas oleks selgelt nähtav / Ensure the butterfly is clearly visible")
            st.markdown("- Vältige uduseid või tumedaid pilte / Avoid blurry or dark images")
            st.markdown("- Proovige madalamat kindluse läviväärtust / Try a lower confidence threshold")


tab1, tab2 = st.tabs(["📷 Live Camera / Kaamera", "📁 Upload Image / Laadi üles"])

with tab1:
    st.header("Kaamera jäädvustamine / Camera Capture")
    st.write("Tee pilt liblikast tuvastamiseks / Take a photo of a butterfly for identification.")
    camera_photo = st.camera_input("Pildista liblikat / Capture a butterfly")
    if camera_photo is not None:
        try:
            image = Image.open(camera_photo).convert("RGB")
            show_prediction(image, "Jäädvustatud pilt / Captured Image")
        except Exception as e:
            st.error(f"Error processing image: {e}")

with tab2:
    st.header("Laadi üles pilt / Upload Image")
    st.write("Laadige üles liblika selge foto tuvastamiseks / Upload a clear photo of a butterfly for identification.")
    uploaded_file = st.file_uploader("Vali pilt... / Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        try:
            image = Image.open(io.BytesIO(uploaded_file.read())).convert("RGB")
            show_prediction(image, "Üleslaetud pilt / Uploaded Image")
        except Exception as e:
            st.error(f"Error processing image: {e}")

# Footer
st.markdown("---")
st.markdown("### Kuidas kasutada / How to use:")
st.markdown("1. **Kaamera jäädvustamine / Camera Capture**: Tehke foto oma seadme kaameraga / Take a photo using your device camera")
st.markdown("2. **Laadi pilt üles / Upload Image**: Vali oma seadmest liblika foto / Choose a butterfly photo from your device")
st.markdown("3. **Parimad tulemused / Best Results**: Kasutage selgeid ja hästi valgustatud fotosid, kus liblikas on selgelt nähtav / Use clear, well-lit photos with the butterfly clearly visible")
st.markdown("4. **Täpsemad seaded / Advanced Options**: Kohandage kindluse lävi ja kasutage TTA paremate tulemuste saamiseks / Adjust confidence threshold and use TTA for better results")

# Debug info
if st.checkbox("Show Debug Info"):
    st.write("**Class Names:**", class_names)
    st.write("**Number of Classes:**", len(class_names))
    st.write("**Model Status:**", "Loaded" if model is not None else "Not Loaded")
    if butterfly_info:
        st.write("**Species Info Available:**", len(butterfly_info))