# mae-waste-classifier-demo / improved_mae_classifier.py
# ysfad's picture
# Update: Enhanced classifier with temperature scaling
# 905ac99 verified
#!/usr/bin/env python3
"""
Improved MAE Waste Classifier with temperature scaling and bias correction
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
import warnings
warnings.filterwarnings("ignore")
# Import MAE model
from mae.models_vit import vit_base_patch16
class ImprovedMAEWasteClassifier:
    """Waste-image classifier built on a finetuned MAE ViT-Base model.

    On top of the raw network output this class applies temperature scaling
    to the logits, down-weights the 'Cardboard' probability by a penalty
    factor (renormalising afterwards), optionally averages predictions over
    several augmented views of the image, and reports "Uncertain" when the
    winning class falls below its class-specific confidence threshold.
    """

    # Channel statistics shared by every preprocessing pipeline (the commonly
    # used ImageNet mean/std values).
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    # Disposal guidance per class label, including the "Uncertain" fallback.
    _DISPOSAL_INSTRUCTIONS = {
        'Cardboard': 'Flatten and place in recycling bin. Remove any tape or staples.',
        'Food Organics': 'Place in compost bin or organic waste collection.',
        'Glass': 'Rinse and place in glass recycling bin. Remove caps and lids.',
        'Metal': 'Rinse cans and place in metal recycling bin.',
        'Miscellaneous Trash': 'Place in general waste bin.',
        'Paper': 'Place in paper recycling bin. Remove any plastic components.',
        'Plastic': 'Check recycling number and place in appropriate plastic recycling bin.',
        'Textile Trash': 'Donate if in good condition, otherwise place in textile recycling.',
        'Vegetation': 'Compost or place in yard waste collection.',
        'Uncertain': 'Please take another photo from a different angle or with better lighting.'
    }

    def __init__(self,
                 model_path=None,
                 hf_model_id=None,
                 device=None,
                 temperature=2.5,  # Temperature scaling to reduce overconfidence
                 cardboard_penalty=0.8):  # Penalty factor for cardboard predictions
        """
        Initialize improved MAE waste classifier with bias correction

        Args:
            model_path: Local path to model file
            hf_model_id: Hugging Face model ID
            device: Device to run on (defaults to 'cuda' when available, else 'cpu')
            temperature: Temperature scaling factor (>1 reduces confidence)
            cardboard_penalty: Penalty factor for cardboard predictions

        Raises:
            ValueError: If temperature is not positive or cardboard_penalty
                is negative.
            FileNotFoundError: If no model checkpoint can be located.
        """
        # Validate before the (potentially slow) model load so bad settings
        # fail fast instead of producing inf/NaN or negative probabilities.
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        if cardboard_penalty < 0:
            raise ValueError(f"cardboard_penalty must be >= 0, got {cardboard_penalty}")

        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.temperature = temperature
        self.cardboard_penalty = cardboard_penalty

        # Class names (must match training order)
        self.class_names = [
            'Cardboard', 'Food Organics', 'Glass', 'Metal',
            'Miscellaneous Trash', 'Paper', 'Plastic', 'Textile Trash', 'Vegetation'
        ]

        # Class-specific confidence thresholds
        self.class_thresholds = {
            'Cardboard': 0.8,  # Higher threshold for cardboard
            'Plastic': 0.6,
            'Metal': 0.6,
            'Glass': 0.6,
            'Paper': 0.6,
            'Food Organics': 0.5,
            'Miscellaneous Trash': 0.5,
            'Textile Trash': 0.4,  # Lower threshold for underrepresented class
            'Vegetation': 0.5
        }

        # Load model
        self.model = self._load_model(model_path, hf_model_id)
        self.model.eval()

        # Data preprocessing
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD)
        ])

        print(f"✅ Improved MAE Classifier loaded on {self.device}")
        print(f"🌡️ Temperature scaling: {self.temperature}")
        print(f"🗂️ Cardboard penalty: {self.cardboard_penalty}")

    def _load_model(self, model_path=None, hf_model_id=None):
        """Load the finetuned MAE model.

        The checkpoint is resolved in this order: ``model_path`` if it exists
        on disk, otherwise ``best_model.pth`` downloaded from ``hf_model_id``,
        otherwise the local default ``output_simple_mae/best_model.pth``.

        Args:
            model_path: Local path to a checkpoint file, or None.
            hf_model_id: Hugging Face repo id to download from, or None.

        Returns:
            The model moved to ``self.device`` with checkpoint weights loaded.

        Raises:
            FileNotFoundError: If no checkpoint could be located.
        """
        # Determine model path
        if model_path and os.path.exists(model_path):
            checkpoint_path = model_path
            print(f"📁 Loading local model from {model_path}")
        else:
            if model_path:
                # Previously this fell through silently; make the fallback visible.
                print(f"⚠️ model_path not found, trying fallbacks: {model_path}")
            if hf_model_id:
                print(f"🌐 Downloading model from HF Hub: {hf_model_id}")
                checkpoint_path = hf_hub_download(
                    repo_id=hf_model_id,
                    filename="best_model.pth",
                    cache_dir="./hf_cache"
                )
                print(f"✅ Downloaded model to: {checkpoint_path}")
            else:
                # Try local file
                local_path = "output_simple_mae/best_model.pth"
                if not os.path.exists(local_path):
                    raise FileNotFoundError("No model found. Provide model_path or hf_model_id")
                checkpoint_path = local_path
                print(f"📁 Using local model: {local_path}")

        # Create model
        model = vit_base_patch16(num_classes=len(self.class_names))

        # Load checkpoint
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from sources you trust.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # Handle different checkpoint formats
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        elif 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint

        # strict=False tolerates key mismatches; report them so a checkpoint
        # that does not actually match the architecture is not used unnoticed.
        load_result = model.load_state_dict(state_dict, strict=False)
        missing = list(getattr(load_result, 'missing_keys', []))
        unexpected = list(getattr(load_result, 'unexpected_keys', []))
        if missing or unexpected:
            print(f"⚠️ Checkpoint/model key mismatch: "
                  f"{len(missing)} missing, {len(unexpected)} unexpected")

        model = model.to(self.device)
        print(f"✅ Loaded finetuned MAE model from {checkpoint_path}")
        return model

    def _apply_temperature_scaling(self, logits):
        """Apply temperature scaling to reduce overconfidence"""
        return logits / self.temperature

    def _apply_class_bias_correction(self, probs):
        """Apply bias correction to reduce cardboard overconfidence.

        Args:
            probs: Probability tensor whose last dimension indexes classes;
                either a single vector or a batch of vectors.

        Returns:
            A new tensor of the same shape with the 'Cardboard' entry scaled
            by ``self.cardboard_penalty`` and every class vector renormalised
            to sum to 1. The input tensor is not modified.
        """
        probs_corrected = probs.clone()
        # Find cardboard class index
        cardboard_idx = self.class_names.index('Cardboard')
        # Index the last axis so batched inputs are corrected per sample
        # rather than scaling an entire row.
        probs_corrected[..., cardboard_idx] *= self.cardboard_penalty
        # Renormalize probabilities
        return probs_corrected / probs_corrected.sum(dim=-1, keepdim=True)

    def _load_image(self, image):
        """Return ``image`` as an RGB PIL image.

        Args:
            image: PIL Image or path (str or os.PathLike) to an image file.

        Raises:
            ValueError: If ``image`` is neither a PIL Image nor a path.
        """
        if isinstance(image, (str, os.PathLike)):
            # convert() loads the pixel data, so the file can be closed here.
            with Image.open(image) as opened:
                return opened.convert('RGB')
        if isinstance(image, Image.Image):
            # RGBA / grayscale / palette inputs would break the 3-channel
            # normalisation, so coerce them to RGB.
            return image if image.mode == 'RGB' else image.convert('RGB')
        raise ValueError("Image must be PIL Image or file path")

    def _build_augment_transforms(self):
        """Build the list of test-time augmentation pipelines (standard one last)."""
        def finish():
            return [
                transforms.ToTensor(),
                transforms.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD),
            ]

        return [
            transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
                *finish()
            ]),
            transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.RandomHorizontalFlip(p=1.0),
                *finish()
            ]),
            transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                *finish()
            ]),
            transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ColorJitter(brightness=0.1, contrast=0.1),
                *finish()
            ]),
            # Standard transform
            self.transform
        ]

    def _predict_probs(self, image, transform):
        """Run one forward pass and return calibrated class probabilities.

        Applies ``transform`` to the image, runs the model without gradient
        tracking, then temperature scaling, softmax and the cardboard bias
        correction. Returns a 1-D tensor over classes.
        """
        with torch.no_grad():
            input_tensor = transform(image).unsqueeze(0).to(self.device)
            logits = self.model(input_tensor)
            scaled_logits = self._apply_temperature_scaling(logits)
            probs = F.softmax(scaled_logits, dim=1).squeeze(0)
            return self._apply_class_bias_correction(probs)

    def _ensemble_prediction(self, image, num_crops=5):
        """Use ensemble of augmented predictions for better stability.

        Args:
            image: RGB PIL image.
            num_crops: Number of augmented views to average; values larger
                than the number of available pipelines (5) use all of them.

        Returns:
            NumPy array of class probabilities averaged over the views.

        Raises:
            ValueError: If ``num_crops`` is less than 1 (there would be
                nothing to average).
        """
        if num_crops < 1:
            raise ValueError(f"num_crops must be >= 1, got {num_crops}")

        all_probs = [
            self._predict_probs(image, transform).cpu().numpy()
            for transform in self._build_augment_transforms()[:num_crops]
        ]
        # Average ensemble predictions
        return np.mean(all_probs, axis=0)

    def _summarize_probs(self, probs, top_k):
        """Turn a probability vector into the final decision and a ranked list.

        Args:
            probs: 1-D array of per-class probabilities in ``class_names`` order.
            top_k: Number of ranked predictions to include. None means all;
                values are clamped to the range [0, number of classes].

        Returns:
            Tuple ``(predicted_class, predicted_confidence, top_predictions)``.
            ``predicted_class`` is "Uncertain" when the best class falls below
            its class-specific threshold; the confidence is still that of the
            best class.
        """
        ranked = np.argsort(probs)[::-1]
        if top_k is None:
            keep = len(ranked)
        else:
            keep = max(0, min(int(top_k), len(ranked)))

        top_predictions = [
            {'class': self.class_names[idx], 'confidence': float(probs[idx])}
            for idx in ranked[:keep]
        ]

        # Decide from the best-ranked class directly so the result does not
        # depend on how many entries the caller asked to see.
        best_idx = ranked[0]
        predicted_class = self.class_names[best_idx]
        predicted_confidence = float(probs[best_idx])

        # Check if prediction meets class-specific threshold
        threshold = self.class_thresholds.get(predicted_class, 0.5)
        if predicted_confidence < threshold:
            # If below threshold, mark as uncertain
            predicted_class = "Uncertain"

        return predicted_class, predicted_confidence, top_predictions

    def classify_image(self, image, top_k=5, use_ensemble=True):
        """
        Classify a waste image with improved confidence calibration

        Args:
            image: PIL Image or path to image
            top_k: Number of top predictions to return
            use_ensemble: Whether to use ensemble prediction

        Returns:
            Dictionary with classification results. On success it contains
            'success' (True), 'predicted_class', 'confidence',
            'top_predictions', 'ensemble_used' and 'temperature'. On any
            failure it contains 'success' (False) and 'error' (message).
        """
        # Errors are reported through the result dict rather than raised;
        # callers are expected to check 'success'.
        try:
            image = self._load_image(image)

            # Get predictions
            if use_ensemble:
                probs = self._ensemble_prediction(image)
            else:
                # Single prediction with improvements
                probs = self._predict_probs(image, self.transform).cpu().numpy()

            predicted_class, predicted_confidence, top_predictions = \
                self._summarize_probs(probs, top_k)

            return {
                'success': True,
                'predicted_class': predicted_class,
                'confidence': predicted_confidence,
                'top_predictions': top_predictions,
                'ensemble_used': use_ensemble,
                'temperature': self.temperature
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def get_disposal_instructions(self, class_name):
        """Get disposal instructions for a waste class.

        Unknown class names get a generic pointer to local guidelines.
        """
        return self._DISPOSAL_INSTRUCTIONS.get(
            class_name, 'Please consult local waste management guidelines.'
        )

    def get_model_info(self):
        """Get model information as a dictionary of descriptive fields."""
        return {
            'model_name': 'Improved ViT-Base MAE',
            'architecture': 'Vision Transformer (ViT-Base)',
            'pretrained': 'MAE (Masked Autoencoder)',
            'num_classes': len(self.class_names),
            'device': str(self.device),
            'temperature': self.temperature,
            'cardboard_penalty': self.cardboard_penalty,
            'improvements': [
                'Temperature scaling for confidence calibration',
                'Class-specific bias correction',
                'Ensemble predictions for stability',
                'Class-specific confidence thresholds'
            ]
        }
def test_improved_classifier(test_image="fail_images/image.webp",
                             hf_model_id="ysfad/mae-waste-classifier"):
    """Smoke-test the improved classifier and print the results.

    Builds an ImprovedMAEWasteClassifier from the given Hugging Face model id,
    classifies ``test_image`` once without and once with the ensemble (when
    the file exists), and prints the model info. Failures and a missing test
    image are reported instead of being skipped silently.

    Args:
        test_image: Path to the sample image to classify.
        hf_model_id: Hugging Face model ID to load the classifier from.
    """
    print("🧪 Testing Improved MAE Waste Classifier...")

    # Load improved classifier
    classifier = ImprovedMAEWasteClassifier(hf_model_id=hf_model_id)

    # Test with a sample image
    if os.path.exists(test_image):
        print(f"\n🔍 Testing with {test_image}")

        # Test both single and ensemble prediction
        print("\n1. Single prediction:")
        single = classifier.classify_image(test_image, use_ensemble=False)
        if single['success']:
            print(f"🎯 Predicted: {single['predicted_class']} ({single['confidence']:.3f})")
        else:
            print(f"❌ Single prediction failed: {single.get('error', 'unknown error')}")

        print("\n2. Ensemble prediction:")
        ensemble = classifier.classify_image(test_image, use_ensemble=True)
        if ensemble['success']:
            print(f"🎯 Predicted: {ensemble['predicted_class']} ({ensemble['confidence']:.3f})")
            print("📊 Top predictions:")
            for rank, pred in enumerate(ensemble['top_predictions'], 1):
                print(f" {rank}. {pred['class']}: {pred['confidence']:.3f}")
        else:
            print(f"❌ Ensemble prediction failed: {ensemble.get('error', 'unknown error')}")
    else:
        print(f"\n⚠️ Test image not found, skipping predictions: {test_image}")

    print("\n🤖 Model Info:")
    for key, value in classifier.get_model_info().items():
        if isinstance(value, list):
            print(f" {key}:")
            for item in value:
                print(f" - {item}")
        else:
            print(f" {key}: {value}")


if __name__ == "__main__":
    test_improved_classifier()