# Enhanced Butterfly Identifier Streamlit App with Better Model Loading
import streamlit as st
from PIL import Image
import torch
import json
import os
import io
import numpy as np
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import warnings
warnings.filterwarnings('ignore')
# Configure Streamlit
st.set_page_config(
page_title="Butterfly Identifier / Liblikamaja ID",
page_icon="🦋",
layout="wide"
)
# Load class names
@st.cache_data
def load_class_names():
try:
        with open("class_names.txt", "r") as f:
            # Skip blank lines so a stray empty line never becomes a phantom class
            return [line.strip() for line in f if line.strip()]
except FileNotFoundError:
st.error("class_names.txt file not found!")
return []
class_names = load_class_names()
# Load butterfly info
@st.cache_data
def load_butterfly_info():
try:
with open("butterfly_info.json", "r") as f:
return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        return {}
butterfly_info = load_butterfly_info()
# Define transform matching training pipeline
inference_transform = A.Compose([
A.Resize(224, 224),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2()
])
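# The mean/std values above are the standard ImageNet statistics; the inference
# preprocessing must match whatever was used during training for the predictions
# to be reliable.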
# Enhanced model loading function
@st.cache_resource
def load_model():
"""Enhanced model loading with better architecture detection"""
# Try different model file names
model_files = [
"butterfly_classifier.pth",
"best_butterfly_model_v3.pth",
"best_butterfly_model.pth"
]
MODEL_PATH = None
for model_file in model_files:
if os.path.exists(model_file):
MODEL_PATH = model_file
break
if MODEL_PATH is None:
st.error("No model file found!")
return None
st.info(f"Loading model from: {MODEL_PATH}")
try:
# Load checkpoint
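        # Note: recent PyTorch releases default torch.load() to weights_only=True;
        # if loading a full training checkpoint ever fails with an unpickling error,
        # passing weights_only=False (for trusted checkpoints only) may be needed.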
checkpoint = torch.load(MODEL_PATH, map_location='cpu')
# Extract model state dict
if 'model_state_dict' in checkpoint:
model_state_dict = checkpoint['model_state_dict']
else:
model_state_dict = checkpoint
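        # Training scripts often wrap the weights under 'model_state_dict' together
        # with optimizer state and epoch counters; a bare state dict is handled by
        # the else branch above.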
num_classes = len(class_names)
# Better architecture detection based on conv_stem channels
def detect_architecture_by_channels(state_dict):
"""Detect architecture by examining conv_stem channels"""
for key, tensor in state_dict.items():
if key.endswith('conv_stem.weight'):
channels = tensor.shape[0] # Output channels
# Map channels to likely architectures
channel_map = {
24: ['tf_efficientnetv2_s', 'efficientnet_b0'],
32: ['tf_efficientnetv2_s', 'efficientnet_b1'],
40: ['efficientnet_b3', 'efficientnet_b2'],
48: ['efficientnet_b4', 'tf_efficientnetv2_m'],
56: ['efficientnet_b5'],
64: ['efficientnet_b6', 'tf_efficientnetv2_l'],
72: ['efficientnet_b7']
}
return channel_map.get(channels, ['tf_efficientnetv2_s'])
return ['tf_efficientnetv2_s']
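        # The stem width differs between EfficientNet variants, so the conv_stem
        # output-channel count is used as a heuristic to narrow down which timm
        # model produced the checkpoint. A wrong guess is harmless: the loop below
        # simply tries the remaining candidate architectures.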
# Get likely architectures based on channels
likely_architectures = detect_architecture_by_channels(model_state_dict)
# Expanded list of architectures to try
architectures_to_try = likely_architectures + [
'tf_efficientnetv2_s',
'efficientnet_b0',
'efficientnet_b1',
'efficientnet_b2',
'efficientnet_b3',
'tf_efficientnetv2_m',
'efficientnet_b4'
]
# Remove duplicates while preserving order
seen = set()
architectures_to_try = [x for x in architectures_to_try if not (x in seen or seen.add(x))]
model = None
successful_arch = None
# Try each architecture
for arch in architectures_to_try:
try:
st.info(f"Trying architecture: {arch}")
# Create model
model = timm.create_model(
arch,
pretrained=False,
num_classes=num_classes,
drop_rate=0.0, # Set to 0 for inference
drop_path_rate=0.0 # Set to 0 for inference
)
# Try to load the state dict
try:
model.load_state_dict(model_state_dict, strict=True)
st.success(f"✅ Successfully loaded model with architecture: {arch}")
successful_arch = arch
break
except Exception as e:
# Try with strict=False
try:
model.load_state_dict(model_state_dict, strict=False)
st.warning(f"⚠️ Loaded {arch} with some mismatched weights")
successful_arch = arch
break
except Exception as e2:
st.warning(f"Failed to load {arch}: {str(e2)}")
continue
except Exception as e:
st.warning(f"Failed to create model {arch}: {str(e)}")
continue
if model is None:
st.error("❌ Failed to load model with any architecture!")
return None
# Set model to evaluation mode
model.eval()
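        # eval() disables dropout and switches normalization layers to their
        # running statistics, which is required for deterministic inference.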
# Display model info
total_params = sum(p.numel() for p in model.parameters())
st.success(f"✅ Model loaded successfully!")
st.info(f"📊 Model: {successful_arch}")
st.info(f"🔢 Parameters: {total_params:,}")
st.info(f"🎯 Classes: {num_classes}")
return model
except Exception as e:
st.error(f"❌ Error loading model: {str(e)}")
return None
# Load model
model = load_model()
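# load_model() is wrapped in st.cache_resource, so the checkpoint is loaded once
# per process and reused across reruns and sessions.
# Optional sketch (not enabled here): inference could be moved to a GPU when one
# is available, but the predict functions below would then also need to send
# their input tensors to the same device.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# if model is not None:
#     model.to(device)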
def predict_butterfly(image, threshold=0.5):
"""Predict butterfly species from image"""
try:
if model is None:
raise ValueError("Model is not loaded.")
if image is None:
return None, None
# Convert to PIL Image if needed
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
if image.mode != 'RGB':
image = image.convert('RGB')
# Apply transforms
transformed = inference_transform(image=np.array(image))
input_tensor = transformed['image'].unsqueeze(0)
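        # ToTensorV2 yields a CHW float tensor; unsqueeze(0) adds the batch
        # dimension the model expects.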
# Make prediction
with torch.no_grad():
output = model(input_tensor)
probabilities = torch.nn.functional.softmax(output[0], dim=0)
confidence, pred = torch.max(probabilities, 0)
if confidence.item() < threshold:
return None, confidence.item()
predicted_class = class_names[pred.item()]
return predicted_class, confidence.item()
except Exception as e:
st.error(f"Prediction error: {str(e)}")
return None, None
def predict_with_tta(image, threshold=0.5, num_tta=5):
"""Predict with Test Time Augmentation for better accuracy"""
try:
if model is None:
raise ValueError("Model is not loaded.")
if image is None:
return None, None
# Convert to PIL Image if needed
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
if image.mode != 'RGB':
image = image.convert('RGB')
# Convert to numpy for albumentations
image_np = np.array(image)
# TTA transforms
tta_transforms = [
A.Compose([
A.Resize(224, 224),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2()
]),
A.Compose([
A.Resize(256, 256),
A.CenterCrop(224, 224),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2()
]),
A.Compose([
A.Resize(224, 224),
A.HorizontalFlip(p=1.0),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2()
]),
A.Compose([
A.Resize(240, 240),
A.Rotate(limit=10, p=1.0),
A.CenterCrop(224, 224),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2()
]),
A.Compose([
A.Resize(224, 224),
A.ColorJitter(brightness=0.1, contrast=0.1, p=1.0),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2()
])
]
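        # Each transform above produces a slightly different view of the image
        # (plain resize, center crop, horizontal flip, small rotation, color
        # jitter); averaging the softmax outputs over these views usually makes
        # the prediction less sensitive to framing and lighting.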
predictions = []
for i, transform in enumerate(tta_transforms[:num_tta]):
transformed = transform(image=image_np)
input_tensor = transformed['image'].unsqueeze(0)
with torch.no_grad():
output = model(input_tensor)
probabilities = torch.nn.functional.softmax(output, dim=1)
predictions.append(probabilities)
# Average predictions
avg_predictions = torch.mean(torch.stack(predictions), dim=0)
confidence, pred = torch.max(avg_predictions, 1)
if confidence.item() < threshold:
return None, confidence.item()
predicted_class = class_names[pred.item()]
return predicted_class, confidence.item()
except Exception as e:
st.error(f"TTA Prediction error: {str(e)}")
return None, None
# UI Code
st.title("🦋 Liblikamaja ID / Butterfly Identifier")
st.write("Tuvasta liblikaid oma kaamera abil või laadi üles pilt! / Identify butterflies using your camera or by uploading an image!")
# Add model status indicator
if model is not None:
st.success("✅ Model loaded and ready!")
else:
st.error("❌ Model not loaded. Please check your model file.")
st.stop()
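# st.stop() halts script execution here, so nothing below renders when the model
# failed to load.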
# Add advanced options
with st.expander("🔧 Advanced Options / Täpsemad seaded"):
confidence_threshold = st.slider(
"Confidence Threshold / Kindluse lävi",
min_value=0.1,
max_value=1.0,
value=0.5,
step=0.05,
help="Higher values = more conservative predictions"
)
use_tta = st.checkbox(
"Use Test Time Augmentation (TTA) / Kasuta TTA",
value=False,
help="Slower but potentially more accurate predictions"
)
if use_tta:
tta_rounds = st.slider(
"TTA Rounds / TTA ringid",
min_value=3,
max_value=8,
value=5,
help="More rounds = slower but potentially more accurate"
)
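# Note: tta_rounds is only defined when the TTA checkbox is ticked; that is safe
# because it is only read inside the use_tta branches below.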
tab1, tab2 = st.tabs(["📷 Live Camera / Kaamera", "📁 Upload Image / Laadi üles"])
with tab1:
st.header("Kaamera jäädvustamine / Camera Capture")
st.write("Tee pilt liblikast tuvastamiseks / Take a photo of a butterfly for identification.")
camera_photo = st.camera_input("Pildista liblikat / Capture a butterfly")
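    # st.camera_input() returns None until the user takes a photo, then a
    # file-like object that PIL can open directly.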
if camera_photo is not None:
try:
image = Image.open(camera_photo).convert("RGB")
col1, col2 = st.columns(2)
with col1:
st.image(image, caption="Jäädvustatud pilt / Captured Image", use_column_width=True)
with col2:
with st.spinner("Pildi analüüsimine... / Analyzing image..."):
if use_tta:
predicted_class, confidence = predict_with_tta(image, confidence_threshold, tta_rounds)
else:
predicted_class, confidence = predict_butterfly(image, confidence_threshold)
if predicted_class and confidence >= confidence_threshold:
st.success(f"**Liblikas / Butterfly: {predicted_class}**")
st.info(f"Confidence: {confidence:.2%}")
if predicted_class in butterfly_info:
st.markdown("**Liigi kirjeldus / About this species:**")
st.write(butterfly_info[predicted_class]["description"])
else:
st.info("No additional information available for this species.")
else:
confidence_text = f" (Confidence: {confidence:.2%})" if confidence else ""
st.warning(f"⚠️ Ma ei tea, mis liblikas see on / I don't know what butterfly this is{confidence_text}")
st.markdown("**Näpunäited paremate tulemuste saavutamiseks / Tips for better results:**")
st.markdown("- Kasutage paremat valgustust / Use better lighting")
st.markdown("- Veenduge, et liblikas oleks selgelt nähtav / Ensure the butterfly is clearly visible")
st.markdown("- Vältige uduseid või tumedaid pilte / Avoid blurry or dark images")
st.markdown("- Proovige madalamat kindluse läviväärtust / Try a lower confidence threshold")
except Exception as e:
st.error(f"Error processing image: {str(e)}")
with tab2:
st.header("Laadi üles pilt / Upload Image")
st.write("Laadige üles liblika selge foto tuvastamiseks / Upload a clear photo of a butterfly for identification.")
uploaded_file = st.file_uploader("Vali pilt... / Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
try:
image_bytes = uploaded_file.read()
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
col1, col2 = st.columns(2)
with col1:
st.image(image, caption="Üleslaetud pilt / Uploaded Image", use_column_width=True)
with col2:
with st.spinner("Pildi analüüsimine... / Analyzing image..."):
if use_tta:
predicted_class, confidence = predict_with_tta(image, confidence_threshold, tta_rounds)
else:
predicted_class, confidence = predict_butterfly(image, confidence_threshold)
if predicted_class and confidence >= confidence_threshold:
st.success(f"**Liblikas / Butterfly: {predicted_class}**")
st.info(f"Confidence: {confidence:.2%}")
if predicted_class in butterfly_info:
st.markdown("**Liigi kirjeldus / About this species:**")
st.write(butterfly_info[predicted_class]["description"])
else:
st.info("No additional information available for this species.")
else:
confidence_text = f" (Confidence: {confidence:.2%})" if confidence else ""
st.warning(f"⚠️ Ma ei tea, mis liblikas see on / I don't know what butterfly this is{confidence_text}")
st.markdown("**Näpunäited paremate tulemuste saavutamiseks / Tips for better results:**")
st.markdown("- Kasutage paremat valgustust / Use better lighting")
st.markdown("- Veenduge, et liblikas oleks selgelt nähtav / Ensure the butterfly is clearly visible")
st.markdown("- Vältige uduseid või tumedaid pilte / Avoid blurry or dark images")
st.markdown("- Proovige madalamat kindluse läviväärtust / Try a lower confidence threshold")
except Exception as e:
st.error(f"Error processing image: {str(e)}")
# Footer
st.markdown("---")
st.markdown("### Kuidas kasutada / How to use:")
st.markdown("1. **Kaamera jäädvustamine / Camera Capture**: Tehke foto oma seadme kaameraga / Take a photo using your device camera")
st.markdown("2. **Laadi pilt üles / Upload Image**: Vali oma seadmest liblika foto / Choose a butterfly photo from your device")
st.markdown("3. **Parimad tulemused / Best Results**: Kasutage selgeid ja hästi valgustatud fotosid, kus liblikas on selgelt nähtav / Use clear, well-lit photos with the butterfly clearly visible")
st.markdown("4. **Täpsemad seaded / Advanced Options**: Kohandage kindluse lävi ja kasutage TTA paremate tulemuste saamiseks / Adjust confidence threshold and use TTA for better results")
# Debug info
if st.checkbox("Show Debug Info"):
st.write("**Class Names:**", class_names)
st.write("**Number of Classes:**", len(class_names))
st.write("**Model Status:**", "Loaded" if model else "Not Loaded")
if butterfly_info:
st.write("**Species Info Available:**", len(butterfly_info))