#!/usr/bin/env python3
"""
Improved MAE Waste Classifier with temperature scaling and bias correction
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
import warnings

warnings.filterwarnings("ignore")

# Import MAE model
from mae.models_vit import vit_base_patch16


class ImprovedMAEWasteClassifier:
    def __init__(self,
                 model_path=None,
                 hf_model_id=None,
                 device=None,
                 temperature=2.5,         # Temperature scaling to reduce overconfidence
                 cardboard_penalty=0.8):  # Penalty factor for cardboard predictions
        """
        Initialize the improved MAE waste classifier with bias correction.

        Args:
            model_path: Local path to the model checkpoint
            hf_model_id: Hugging Face model ID
            device: Device to run on
            temperature: Temperature scaling factor (>1 reduces confidence)
            cardboard_penalty: Penalty factor for cardboard predictions
        """
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.temperature = temperature
        self.cardboard_penalty = cardboard_penalty

        # Class names (must match training order)
        self.class_names = [
            'Cardboard', 'Food Organics', 'Glass', 'Metal',
            'Miscellaneous Trash', 'Paper', 'Plastic', 'Textile Trash', 'Vegetation'
        ]

        # Class-specific confidence thresholds
        self.class_thresholds = {
            'Cardboard': 0.8,            # Higher threshold for cardboard
            'Plastic': 0.6,
            'Metal': 0.6,
            'Glass': 0.6,
            'Paper': 0.6,
            'Food Organics': 0.5,
            'Miscellaneous Trash': 0.5,
            'Textile Trash': 0.4,        # Lower threshold for underrepresented class
            'Vegetation': 0.5
        }

        # Load model
        self.model = self._load_model(model_path, hf_model_id)
        self.model.eval()

        # Data preprocessing
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        print(f"Improved MAE Classifier loaded on {self.device}")
        print(f"Temperature scaling: {self.temperature}")
        print(f"Cardboard penalty: {self.cardboard_penalty}")

    def _load_model(self, model_path=None, hf_model_id=None):
        """Load the finetuned MAE model."""
        # Determine model path
        if model_path and os.path.exists(model_path):
            checkpoint_path = model_path
            print(f"Loading local model from {model_path}")
        elif hf_model_id:
            print(f"Downloading model from HF Hub: {hf_model_id}")
            checkpoint_path = hf_hub_download(
                repo_id=hf_model_id,
                filename="best_model.pth",
                cache_dir="./hf_cache"
            )
            print(f"Downloaded model to: {checkpoint_path}")
        else:
            # Try local file
            local_path = "output_simple_mae/best_model.pth"
            if os.path.exists(local_path):
                checkpoint_path = local_path
                print(f"Using local model: {local_path}")
            else:
                raise FileNotFoundError("No model found. Provide model_path or hf_model_id")

        # Create model
        model = vit_base_patch16(num_classes=len(self.class_names))

        # Load checkpoint
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # Handle different checkpoint formats
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        elif 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint

        # Load state dict
        model.load_state_dict(state_dict, strict=False)
        model = model.to(self.device)

        print(f"Loaded finetuned MAE model from {checkpoint_path}")
        return model

    def _apply_temperature_scaling(self, logits):
        """Apply temperature scaling to reduce overconfidence."""
        return logits / self.temperature

    def _apply_class_bias_correction(self, probs):
        """Apply bias correction to reduce cardboard overconfidence."""
        probs_corrected = probs.clone()

        # Find cardboard class index
        cardboard_idx = self.class_names.index('Cardboard')

        # Apply penalty to cardboard predictions
        probs_corrected[cardboard_idx] *= self.cardboard_penalty

        # Renormalize probabilities
        probs_corrected = probs_corrected / probs_corrected.sum()
        return probs_corrected
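
    # Illustrative note (hypothetical numbers, not part of the pipeline): temperature
    # scaling computes softmax(logits / T). With T > 1 the distribution flattens, e.g.
    # for logits [4.0, 1.0, 0.0] the top-class probability drops from roughly 0.94 at
    # T=1 to roughly 0.67 at T=2.5, which is how the overconfidence is reduced before
    # the cardboard penalty and thresholds are applied.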

    def _ensemble_prediction(self, image, num_crops=5):
        """Use an ensemble of augmented predictions for better stability."""
        # Different augmentation transforms
        augment_transforms = [
            transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ]),
            transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.RandomHorizontalFlip(p=1.0),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ]),
            transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ]),
            transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ColorJitter(brightness=0.1, contrast=0.1),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ]),
            # Standard transform
            self.transform
        ]

        all_probs = []
        with torch.no_grad():
            for transform in augment_transforms[:num_crops]:
                # Apply transform
                input_tensor = transform(image).unsqueeze(0).to(self.device)

                # Get prediction
                logits = self.model(input_tensor)

                # Apply temperature scaling
                scaled_logits = self._apply_temperature_scaling(logits)

                # Get probabilities
                probs = F.softmax(scaled_logits, dim=1).squeeze(0)

                # Apply bias correction
                corrected_probs = self._apply_class_bias_correction(probs)
                all_probs.append(corrected_probs.cpu().numpy())

        # Average ensemble predictions
        ensemble_probs = np.mean(all_probs, axis=0)
        return ensemble_probs

    def classify_image(self, image, top_k=5, use_ensemble=True):
        """
        Classify a waste image with improved confidence calibration.

        Args:
            image: PIL Image or path to an image file
            top_k: Number of top predictions to return
            use_ensemble: Whether to use ensemble prediction

        Returns:
            Dictionary with classification results
        """
        try:
            # Load image if a path was provided
            if isinstance(image, str):
                image = Image.open(image).convert('RGB')
            elif not isinstance(image, Image.Image):
                raise ValueError("Image must be a PIL Image or a file path")

            # Get predictions
            if use_ensemble:
                probs = self._ensemble_prediction(image)
            else:
                # Single prediction with calibration applied
                input_tensor = self.transform(image).unsqueeze(0).to(self.device)
                with torch.no_grad():
                    logits = self.model(input_tensor)
                    scaled_logits = self._apply_temperature_scaling(logits)
                    probs = F.softmax(scaled_logits, dim=1).squeeze(0)
                    probs = self._apply_class_bias_correction(probs)
                    probs = probs.cpu().numpy()

            # Get top predictions
            top_indices = np.argsort(probs)[::-1][:top_k]
            top_predictions = []
            for idx in top_indices:
                class_name = self.class_names[idx]
                confidence = float(probs[idx])
                top_predictions.append({
                    'class': class_name,
                    'confidence': confidence
                })

            # Determine the final prediction with class-specific thresholds
            predicted_class = top_predictions[0]['class']
            predicted_confidence = top_predictions[0]['confidence']

            # If the confidence falls below the class-specific threshold, mark as uncertain
            threshold = self.class_thresholds.get(predicted_class, 0.5)
            if predicted_confidence < threshold:
                predicted_class = "Uncertain"

            return {
                'success': True,
                'predicted_class': predicted_class,
                'confidence': predicted_confidence,
                'top_predictions': top_predictions,
                'ensemble_used': use_ensemble,
                'temperature': self.temperature
            }

        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }
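
    # Illustrative note (hypothetical numbers): with the thresholds defined above, a top
    # prediction of Cardboard at confidence 0.72 is reported as "Uncertain" (threshold 0.8),
    # while Textile Trash at 0.45 is kept (threshold 0.4). The calibrated confidence is
    # returned in the result dictionary either way.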

    def get_disposal_instructions(self, class_name):
        """Get disposal instructions for a waste class."""
        instructions = {
            'Cardboard': 'Flatten and place in recycling bin. Remove any tape or staples.',
            'Food Organics': 'Place in compost bin or organic waste collection.',
            'Glass': 'Rinse and place in glass recycling bin. Remove caps and lids.',
            'Metal': 'Rinse cans and place in metal recycling bin.',
            'Miscellaneous Trash': 'Place in general waste bin.',
            'Paper': 'Place in paper recycling bin. Remove any plastic components.',
            'Plastic': 'Check recycling number and place in appropriate plastic recycling bin.',
            'Textile Trash': 'Donate if in good condition, otherwise place in textile recycling.',
            'Vegetation': 'Compost or place in yard waste collection.',
            'Uncertain': 'Please take another photo from a different angle or with better lighting.'
        }
        return instructions.get(class_name, 'Please consult local waste management guidelines.')

    def get_model_info(self):
        """Get model information."""
        return {
            'model_name': 'Improved ViT-Base MAE',
            'architecture': 'Vision Transformer (ViT-Base)',
            'pretrained': 'MAE (Masked Autoencoder)',
            'num_classes': len(self.class_names),
            'device': str(self.device),
            'temperature': self.temperature,
            'cardboard_penalty': self.cardboard_penalty,
            'improvements': [
                'Temperature scaling for confidence calibration',
                'Class-specific bias correction',
                'Ensemble predictions for stability',
                'Class-specific confidence thresholds'
            ]
        }


def test_improved_classifier():
    """Test the improved classifier."""
    print("Testing Improved MAE Waste Classifier...")

    # Load improved classifier
    classifier = ImprovedMAEWasteClassifier(hf_model_id="ysfad/mae-waste-classifier")

    # Test with a sample image
    test_image = "fail_images/image.webp"
    if os.path.exists(test_image):
        print(f"\nTesting with {test_image}")

        # Test both single and ensemble prediction
        print("\n1. Single prediction:")
        result1 = classifier.classify_image(test_image, use_ensemble=False)
        if result1['success']:
            print(f"Predicted: {result1['predicted_class']} ({result1['confidence']:.3f})")

        print("\n2. Ensemble prediction:")
        result2 = classifier.classify_image(test_image, use_ensemble=True)
        if result2['success']:
            print(f"Predicted: {result2['predicted_class']} ({result2['confidence']:.3f})")
            print("Top predictions:")
            for i, pred in enumerate(result2['top_predictions'], 1):
                print(f"  {i}. {pred['class']}: {pred['confidence']:.3f}")

    print("\nModel Info:")
    info = classifier.get_model_info()
    for key, value in info.items():
        if isinstance(value, list):
            print(f"  {key}:")
            for item in value:
                print(f"    - {item}")
        else:
            print(f"  {key}: {value}")


if __name__ == "__main__":
    test_improved_classifier()
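

# Minimal usage sketch from another module (hypothetical paths; adjust to your setup):
#
#     classifier = ImprovedMAEWasteClassifier(model_path="output_simple_mae/best_model.pth",
#                                             temperature=2.5, cardboard_penalty=0.8)
#     result = classifier.classify_image("some_photo.jpg", use_ensemble=True)
#     if result['success']:
#         print(result['predicted_class'], result['confidence'])
#         print(classifier.get_disposal_instructions(result['predicted_class']))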