# mae-waste-classifier-demo / mae_waste_classifier.py
# Uploaded by: ysfad
# Commit: "Upload mae_waste_classifier.py with huggingface_hub" (0007f63, verified)
#!/usr/bin/env python3
"""MAE ViT-Base waste classifier for inference."""
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import timm
import os
import json
from huggingface_hub import hf_hub_download
class MAEWasteClassifier:
    """Waste classifier using a finetuned MAE ViT-Base model.

    On construction the checkpoint is located (explicit local path, Hugging
    Face Hub download, or a hard-coded local fallback). It is then loaded into
    a timm ``vit_base_patch16_224`` model, and a 224x224 resize + normalize
    preprocessing pipeline is built.
    """

    def __init__(self, model_path=None, hf_model_id="ysfad/mae-waste-classifier", device=None):
        """Locate the checkpoint, load the model and build preprocessing.

        Args:
            model_path: Optional path to a local checkpoint. It is used only
                if it exists; otherwise the Hub is tried.
            hf_model_id: Hub repo id that ``best_model.pth`` is downloaded from.
            device: Torch device. Defaults to ``'cuda'`` when available,
                otherwise ``'cpu'``.

        Raises:
            FileNotFoundError: If no checkpoint could be found or downloaded.
        """
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.hf_model_id = hf_model_id
        self.model_path = self._resolve_model_path(model_path, hf_model_id)

        # Class names from training; the list index is the model's output index.
        self.class_names = [
            'Cardboard', 'Food Organics', 'Glass', 'Metal',
            'Miscellaneous Trash', 'Paper', 'Plastic',
            'Textile Trash', 'Vegetation'
        ]

        # Disposal instructions keyed by class name.
        self.disposal_instructions = {
            "Cardboard": "Flatten and place in recycling bin. Remove any tape or staples.",
            "Food Organics": "Compost in organic waste bin or home composter.",
            "Glass": "Rinse and place in glass recycling. Remove lids and caps.",
            "Metal": "Rinse aluminum/steel cans and place in recycling bin.",
            "Miscellaneous Trash": "Dispose in general waste bin. Cannot be recycled.",
            "Paper": "Place clean paper in recycling. Remove plastic windows from envelopes.",
            "Plastic": "Check recycling number. Rinse containers before recycling.",
            "Textile Trash": "Donate if reusable, otherwise dispose in textile recycling.",
            "Vegetation": "Compost in organic waste or use for mulch in garden."
        }

        # _load_model sizes the classifier head from class_names, so it must run after them.
        self.model = self._load_model()

        # Image preprocessing: fixed 224x224 resize, tensor conversion, per-channel normalization.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        print(f"βœ… MAE Waste Classifier loaded on {self.device}")
        print(f"πŸ“Š Model: ViT-Base MAE, Classes: {len(self.class_names)}")

    @staticmethod
    def _resolve_model_path(model_path, hf_model_id):
        """Return a checkpoint path: existing local file, Hub download, or local fallback.

        Raises:
            FileNotFoundError: If the Hub download fails and the fallback file is missing.
        """
        if model_path and os.path.exists(model_path):
            print(f"πŸ“ Using local model: {model_path}")
            return model_path
        if model_path:
            # Say so explicitly, so a typo in model_path is not mistaken for
            # a successful local load.
            print(f"⚠️ Local model not found at {model_path}, trying HF Hub instead")

        try:
            print(f"🌐 Downloading model from HF Hub: {hf_model_id}")
            downloaded = hf_hub_download(
                repo_id=hf_model_id,
                filename="best_model.pth",
                cache_dir="./hf_cache"
            )
        except Exception as e:  # deliberate fallback boundary: any download failure
            print(f"⚠️ Could not download from HF Hub: {e}")
            fallback = "output_simple_mae/best_model.pth"
            if not os.path.exists(fallback):
                raise FileNotFoundError(
                    f"Model not found locally at {fallback} and could not download from HF Hub"
                ) from e
            return fallback

        print(f"βœ… Downloaded model to: {downloaded}")
        return downloaded

    def _load_model(self):
        """Build a ViT-Base/16 via timm and load the finetuned weights into it.

        Returns:
            The model, moved to ``self.device`` and switched to eval mode.

        Raises:
            Any error from model creation, checkpoint loading or
            ``load_state_dict``, re-raised after being printed.
        """
        try:
            model = timm.create_model(
                'vit_base_patch16_224', pretrained=False, num_classes=len(self.class_names)
            )
            # NOTE(review): torch.load unpickles the file and the checkpoint may
            # come from the HF Hub, so only point this at trusted repos. On
            # torch>=2.6 the default weights_only=True may reject checkpoints
            # that store non-tensor objects; confirm against the saved file.
            checkpoint = torch.load(self.model_path, map_location=self.device)
            # Accept either a training dict wrapping the weights or a bare state dict.
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            else:
                state_dict = checkpoint
            model.load_state_dict(state_dict)
            model.to(self.device)
            model.eval()
            print(f"βœ… Loaded finetuned MAE model from {self.model_path}")
            return model
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise

    @staticmethod
    def _prepare_image(image):
        """Return an RGB PIL image from a PIL image or a filesystem path.

        Raises:
            ValueError: If ``image`` is neither a PIL image nor a path.
        """
        if isinstance(image, (str, os.PathLike)):
            # Context manager closes the file handle; convert() loads the pixels first.
            with Image.open(image) as opened:
                return opened.convert('RGB')
        if isinstance(image, Image.Image):
            # RGBA / grayscale inputs would give the wrong channel count to Normalize.
            return image if image.mode == 'RGB' else image.convert('RGB')
        raise ValueError("Image must be PIL Image or path string")

    def _format_predictions(self, outputs, top_k):
        """Turn logits for a single image into a ranked list of class/confidence dicts.

        Only the first row of ``outputs`` is read. ``top_k`` is clamped to
        ``[1, len(self.class_names)]`` so that a best prediction always exists.
        """
        probabilities = F.softmax(outputs, dim=1)
        k = max(1, min(top_k, len(self.class_names)))
        top_probs, top_indices = torch.topk(probabilities, k=k)
        return [
            {'class': self.class_names[idx.item()], 'confidence': prob.item()}
            for prob, idx in zip(top_probs[0], top_indices[0])
        ]

    def classify_image(self, image, top_k=5):
        """
        Classify a waste image.

        Args:
            image: PIL Image, or a path (str / os.PathLike) to an image file.
                Non-RGB images are converted to RGB.
            top_k: Number of top predictions to return, clamped to
                [1, number of classes].

        Returns:
            dict: On success, ``{'success': True, 'predicted_class',
            'confidence', 'top_predictions'}``. On any failure,
            ``{'success': False, 'error': <message>}``.
        """
        try:
            rgb_image = self._prepare_image(image)
            # Add a batch dimension of 1 for the model.
            input_tensor = self.transform(rgb_image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                outputs = self.model(input_tensor)
                top_predictions = self._format_predictions(outputs, top_k)

            best_pred = top_predictions[0]
            return {
                'success': True,
                'predicted_class': best_pred['class'],
                'confidence': best_pred['confidence'],
                'top_predictions': top_predictions
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def get_disposal_instructions(self, class_name):
        """Return the disposal instructions for ``class_name``, or a generic fallback message."""
        return self.disposal_instructions.get(class_name, "No specific instructions available.")

    def get_model_info(self):
        """Return a dict describing the loaded model, device and checkpoint path."""
        return {
            'model_name': 'ViT-Base MAE',
            'architecture': 'Vision Transformer (ViT-Base)',
            'pretrained': 'MAE (Masked Autoencoder)',
            'num_classes': len(self.class_names),
            'device': self.device,
            'model_path': self.model_path
        }
# Smoke-test the classifier when this file is run as a script.
if __name__ == "__main__":
    import sys
    import traceback

    print("πŸ§ͺ Testing MAE Waste Classifier...")
    try:
        # Initialize classifier (loads or downloads the checkpoint).
        classifier = MAEWasteClassifier()

        # Only the first of these images that exists on disk is classified.
        test_images = [
            "fail_images/image.webp",
            "fail_images/IMG_9501.webp",
        ]
        for img_path in test_images:
            if os.path.exists(img_path):
                print(f"\nπŸ” Testing with {img_path}")
                result = classifier.classify_image(img_path)
                if result['success']:
                    print(f"βœ… Predicted: {result['predicted_class']} ({result['confidence']:.3f})")
                    print(f"πŸ“‹ Instructions: {classifier.get_disposal_instructions(result['predicted_class'])}")
                    print("\nπŸ“Š Top predictions:")
                    for i, pred in enumerate(result['top_predictions'][:3], 1):
                        print(f"  {i}. {pred['class']}: {pred['confidence']:.3f}")
                else:
                    print(f"❌ Error: {result['error']}")
                break
        else:
            # for/else: runs only when the loop ends without `break`,
            # i.e. none of the test images exist.
            print("ℹ️ No test images found, but classifier loaded successfully!")

        # Print model info
        info = classifier.get_model_info()
        print("\nπŸ€– Model Info:")
        for key, value in info.items():
            print(f"  {key}: {value}")

        print("\nSuccess!")
    except Exception as e:
        print(f"❌ Error: {e}")
        traceback.print_exc()
        # Exit non-zero so shells/CI see the failure instead of status 0.
        sys.exit(1)