#!/usr/bin/env python3
"""
Improved MAE Waste Classifier with temperature scaling and bias correction
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
import warnings
warnings.filterwarnings("ignore")

# Import MAE model
from mae.models_vit import vit_base_patch16


class ImprovedMAEWasteClassifier:
    def __init__(self, model_path=None, hf_model_id=None, device=None,
                 temperature=2.5,         # Temperature scaling to reduce overconfidence
                 cardboard_penalty=0.8):  # Penalty factor for cardboard predictions
        """
        Initialize improved MAE waste classifier with bias correction

        Args:
            model_path: Local path to model file
            hf_model_id: Hugging Face model ID
            device: Device to run on
            temperature: Temperature scaling factor (>1 reduces confidence)
            cardboard_penalty: Penalty factor for cardboard predictions
        """
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.temperature = temperature
        self.cardboard_penalty = cardboard_penalty

        # Class names (must match training order)
        self.class_names = [
            'Cardboard', 'Food Organics', 'Glass', 'Metal',
            'Miscellaneous Trash', 'Paper', 'Plastic',
            'Textile Trash', 'Vegetation'
        ]

        # Class-specific confidence thresholds
        self.class_thresholds = {
            'Cardboard': 0.8,      # Higher threshold for cardboard
            'Plastic': 0.6,
            'Metal': 0.6,
            'Glass': 0.6,
            'Paper': 0.6,
            'Food Organics': 0.5,
            'Miscellaneous Trash': 0.5,
            'Textile Trash': 0.4,  # Lower threshold for underrepresented class
            'Vegetation': 0.5
        }

        # Load model
        self.model = self._load_model(model_path, hf_model_id)
        self.model.eval()

        # Data preprocessing
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        print(f"✅ Improved MAE Classifier loaded on {self.device}")
        print(f"🌡️ Temperature scaling: {self.temperature}")
        print(f"🗂️ Cardboard penalty: {self.cardboard_penalty}")

    def _load_model(self, model_path=None, hf_model_id=None):
        """Load the finetuned MAE model"""
        # Determine model path
        if model_path and os.path.exists(model_path):
            checkpoint_path = model_path
            print(f"📁 Loading local model from {model_path}")
        elif hf_model_id:
            print(f"🌐 Downloading model from HF Hub: {hf_model_id}")
            checkpoint_path = hf_hub_download(
                repo_id=hf_model_id,
                filename="best_model.pth",
                cache_dir="./hf_cache"
            )
            print(f"✅ Downloaded model to: {checkpoint_path}")
        else:
            # Try local file
            local_path = "output_simple_mae/best_model.pth"
            if os.path.exists(local_path):
                checkpoint_path = local_path
                print(f"📁 Using local model: {local_path}")
            else:
                raise FileNotFoundError(
                    "No model found. Provide model_path or hf_model_id"
                )

        # Create model
        model = vit_base_patch16(num_classes=len(self.class_names))

        # Load checkpoint
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # Handle different checkpoint formats
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        elif 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint

        # Load state dict
        model.load_state_dict(state_dict, strict=False)
        model = model.to(self.device)

        print(f"✅ Loaded finetuned MAE model from {checkpoint_path}")
        return model

    def _apply_temperature_scaling(self, logits):
        """Apply temperature scaling to reduce overconfidence"""
        return logits / self.temperature
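    def calibrate_temperature(self, val_logits, val_labels, max_iter=50):
        """
        Hedged sketch, not called anywhere in this file: fit the temperature
        on held-out data instead of hardcoding 2.5. This follows standard
        temperature scaling (Guo et al., 2017): minimize the NLL of
        logits / T on a validation set, with T as the only free parameter.

        Assumed inputs (produced elsewhere, not by this class):
            val_logits: (N, num_classes) tensor of raw model logits
            val_labels: (N,) tensor of integer class labels
        """
        temperature = nn.Parameter(
            torch.full((1,), float(self.temperature), device=self.device)
        )
        optimizer = torch.optim.LBFGS([temperature], lr=0.01, max_iter=max_iter)
        nll = nn.CrossEntropyLoss()
        val_logits = val_logits.to(self.device)
        val_labels = val_labels.to(self.device)

        def closure():
            # LBFGS re-evaluates the loss several times per step
            optimizer.zero_grad()
            loss = nll(val_logits / temperature, val_labels)
            loss.backward()
            return loss

        optimizer.step(closure)
        self.temperature = float(temperature.detach().item())
        return self.temperature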
    def _apply_class_bias_correction(self, probs):
        """Apply bias correction to reduce cardboard overconfidence"""
        probs_corrected = probs.clone()

        # Find cardboard class index
        cardboard_idx = self.class_names.index('Cardboard')

        # Apply penalty to cardboard predictions
        probs_corrected[cardboard_idx] *= self.cardboard_penalty

        # Renormalize probabilities
        probs_corrected = probs_corrected / probs_corrected.sum()

        return probs_corrected

    def _ensemble_prediction(self, image, num_crops=5):
        """Use ensemble of augmented predictions for better stability"""
        # Different augmentation transforms
        augment_transforms = [
            transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ]),
            transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.RandomHorizontalFlip(p=1.0),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ]),
            transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ]),
            transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ColorJitter(brightness=0.1, contrast=0.1),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ]),
            # Standard transform
            self.transform
        ]

        all_probs = []
        with torch.no_grad():
            for transform in augment_transforms[:num_crops]:
                # Apply transform
                input_tensor = transform(image).unsqueeze(0).to(self.device)

                # Get prediction
                logits = self.model(input_tensor)

                # Apply temperature scaling
                scaled_logits = self._apply_temperature_scaling(logits)

                # Get probabilities
                probs = F.softmax(scaled_logits, dim=1).squeeze(0)

                # Apply bias correction
                corrected_probs = self._apply_class_bias_correction(probs)

                all_probs.append(corrected_probs.cpu().numpy())

        # Average ensemble predictions
        ensemble_probs = np.mean(all_probs, axis=0)
        return ensemble_probs
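    def _apply_prior_correction(self, probs, class_weights):
        """
        Hedged sketch, not wired into classify_image: generalizes the single
        cardboard penalty above to a per-class weight vector, e.g. factors
        derived from validation confusion rates. `class_weights` is a
        hypothetical dict {class_name: factor}; unlisted classes default to 1.
        """
        weights = torch.tensor(
            [class_weights.get(name, 1.0) for name in self.class_names],
            device=probs.device, dtype=probs.dtype
        )
        corrected = probs * weights
        # Renormalize so the corrected scores remain a distribution
        return corrected / corrected.sum()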
    def classify_image(self, image, top_k=5, use_ensemble=True):
        """
        Classify a waste image with improved confidence calibration

        Args:
            image: PIL Image or path to image
            top_k: Number of top predictions to return
            use_ensemble: Whether to use ensemble prediction

        Returns:
            Dictionary with classification results
        """
        try:
            # Load image if path provided
            if isinstance(image, str):
                image = Image.open(image).convert('RGB')
            elif not isinstance(image, Image.Image):
                raise ValueError("Image must be PIL Image or file path")

            # Get predictions
            if use_ensemble:
                probs = self._ensemble_prediction(image)
            else:
                # Single prediction with improvements
                input_tensor = self.transform(image).unsqueeze(0).to(self.device)
                with torch.no_grad():
                    logits = self.model(input_tensor)
                    scaled_logits = self._apply_temperature_scaling(logits)
                    probs = F.softmax(scaled_logits, dim=1).squeeze(0)
                    probs = self._apply_class_bias_correction(probs)
                    probs = probs.cpu().numpy()

            # Get top predictions
            top_indices = np.argsort(probs)[::-1][:top_k]
            top_predictions = []
            for idx in top_indices:
                class_name = self.class_names[idx]
                confidence = float(probs[idx])
                top_predictions.append({
                    'class': class_name,
                    'confidence': confidence
                })

            # Determine final prediction with class-specific thresholds
            predicted_class = top_predictions[0]['class']
            predicted_confidence = top_predictions[0]['confidence']

            # Check if prediction meets its class-specific threshold
            threshold = self.class_thresholds.get(predicted_class, 0.5)
            if predicted_confidence < threshold:
                # Below threshold: mark as uncertain but report the raw confidence
                predicted_class = "Uncertain"

            return {
                'success': True,
                'predicted_class': predicted_class,
                'confidence': predicted_confidence,
                'top_predictions': top_predictions,
                'ensemble_used': use_ensemble,
                'temperature': self.temperature
            }

        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def get_disposal_instructions(self, class_name):
        """Get disposal instructions for a waste class"""
        instructions = {
            'Cardboard': 'Flatten and place in recycling bin. Remove any tape or staples.',
            'Food Organics': 'Place in compost bin or organic waste collection.',
            'Glass': 'Rinse and place in glass recycling bin. Remove caps and lids.',
            'Metal': 'Rinse cans and place in metal recycling bin.',
            'Miscellaneous Trash': 'Place in general waste bin.',
            'Paper': 'Place in paper recycling bin. Remove any plastic components.',
            'Plastic': 'Check recycling number and place in appropriate plastic recycling bin.',
            'Textile Trash': 'Donate if in good condition, otherwise place in textile recycling.',
            'Vegetation': 'Compost or place in yard waste collection.',
            'Uncertain': 'Please take another photo from a different angle or with better lighting.'
        }
        return instructions.get(class_name, 'Please consult local waste management guidelines.')

    def get_model_info(self):
        """Get model information"""
        return {
            'model_name': 'Improved ViT-Base MAE',
            'architecture': 'Vision Transformer (ViT-Base)',
            'pretrained': 'MAE (Masked Autoencoder)',
            'num_classes': len(self.class_names),
            'device': str(self.device),
            'temperature': self.temperature,
            'cardboard_penalty': self.cardboard_penalty,
            'improvements': [
                'Temperature scaling for confidence calibration',
                'Class-specific bias correction',
                'Ensemble predictions for stability',
                'Class-specific confidence thresholds'
            ]
        }
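def classify_directory(classifier, image_dir, use_ensemble=False):
    """
    Hedged usage sketch, not part of the original API: run the classifier
    over every image in a directory and print each prediction alongside its
    disposal instructions. `image_dir` is a hypothetical path supplied by
    the caller.
    """
    results = {}
    for name in sorted(os.listdir(image_dir)):
        if not name.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')):
            continue
        result = classifier.classify_image(
            os.path.join(image_dir, name), use_ensemble=use_ensemble
        )
        if result['success']:
            results[name] = result['predicted_class']
            print(f"{name}: {result['predicted_class']} "
                  f"({result['confidence']:.3f}) -> "
                  f"{classifier.get_disposal_instructions(result['predicted_class'])}")
        else:
            results[name] = None
            print(f"{name}: failed ({result['error']})")
    return results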
def test_improved_classifier():
    """Test the improved classifier"""
    print("🧪 Testing Improved MAE Waste Classifier...")

    # Load improved classifier
    classifier = ImprovedMAEWasteClassifier(hf_model_id="ysfad/mae-waste-classifier")

    # Test with a sample image
    test_image = "fail_images/image.webp"
    if os.path.exists(test_image):
        print(f"\n🔍 Testing with {test_image}")

        # Test both single and ensemble prediction
        print("\n1. Single prediction:")
        result1 = classifier.classify_image(test_image, use_ensemble=False)
        if result1['success']:
            print(f"🎯 Predicted: {result1['predicted_class']} ({result1['confidence']:.3f})")

        print("\n2. Ensemble prediction:")
        result2 = classifier.classify_image(test_image, use_ensemble=True)
        if result2['success']:
            print(f"🎯 Predicted: {result2['predicted_class']} ({result2['confidence']:.3f})")
            print("📊 Top predictions:")
            for i, pred in enumerate(result2['top_predictions'], 1):
                print(f"   {i}. {pred['class']}: {pred['confidence']:.3f}")

    print("\n🤖 Model Info:")
    info = classifier.get_model_info()
    for key, value in info.items():
        if isinstance(value, list):
            print(f"   {key}:")
            for item in value:
                print(f"     - {item}")
        else:
            print(f"   {key}: {value}")


if __name__ == "__main__":
    test_improved_classifier()