Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
"""MAE ViT-Base waste classifier for inference.""" | |
import torch | |
import torch.nn.functional as F | |
from torchvision import transforms | |
from PIL import Image | |
import timm | |
import os | |
import json | |
from huggingface_hub import hf_hub_download | |
class MAEWasteClassifier:
    """Waste classifier using a finetuned MAE ViT-Base model.

    On construction the checkpoint is located (local path, then Hugging Face
    Hub, then a fixed local fallback), a timm ``vit_base_patch16_224`` model is
    built with one output per entry in ``class_names``, and the weights are
    loaded in eval mode on the selected device.
    """

    def __init__(self, model_path=None, hf_model_id="ysfad/mae-waste-classifier", device=None):
        """Locate the checkpoint, load the model and build the preprocessing.

        Args:
            model_path: Optional path to a local checkpoint. It is used only
                if the path exists; otherwise the Hub download is attempted.
            hf_model_id: Hugging Face Hub repo id from which
                ``best_model.pth`` is downloaded (cached under ``./hf_cache``).
            device: Device to run on. Defaults to ``'cuda'`` when available,
                else ``'cpu'``.

        Raises:
            FileNotFoundError: If the Hub download fails and the fallback
                checkpoint ``output_simple_mae/best_model.pth`` does not exist.
        """
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.hf_model_id = hf_model_id

        # Resolve the checkpoint: explicit local path > HF Hub > local fallback.
        if model_path and os.path.exists(model_path):
            self.model_path = model_path
            print(f"π Using local model: {model_path}")
        else:
            try:
                print(f"π Downloading model from HF Hub: {hf_model_id}")
                self.model_path = hf_hub_download(
                    repo_id=hf_model_id,
                    filename="best_model.pth",
                    cache_dir="./hf_cache"
                )
                print(f"β Downloaded model to: {self.model_path}")
            except Exception as e:
                # Best-effort download: any failure falls through to the
                # fixed local fallback path before giving up.
                print(f"β οΈ Could not download from HF Hub: {e}")
                self.model_path = "output_simple_mae/best_model.pth"
                if not os.path.exists(self.model_path):
                    raise FileNotFoundError(
                        f"Model not found locally at {self.model_path} and could not download from HF Hub"
                    ) from e

        # Class names from training; list position is the model output index.
        self.class_names = [
            'Cardboard', 'Food Organics', 'Glass', 'Metal',
            'Miscellaneous Trash', 'Paper', 'Plastic',
            'Textile Trash', 'Vegetation'
        ]

        # Disposal instructions keyed by class name.
        self.disposal_instructions = {
            "Cardboard": "Flatten and place in recycling bin. Remove any tape or staples.",
            "Food Organics": "Compost in organic waste bin or home composter.",
            "Glass": "Rinse and place in glass recycling. Remove lids and caps.",
            "Metal": "Rinse aluminum/steel cans and place in recycling bin.",
            "Miscellaneous Trash": "Dispose in general waste bin. Cannot be recycled.",
            "Paper": "Place clean paper in recycling. Remove plastic windows from envelopes.",
            "Plastic": "Check recycling number. Rinse containers before recycling.",
            "Textile Trash": "Donate if reusable, otherwise dispose in textile recycling.",
            "Vegetation": "Compost in organic waste or use for mulch in garden."
        }

        # Load model
        self.model = self._load_model()

        # Image preprocessing: fixed 224x224 input, tensor conversion, then
        # per-channel normalization (these are the standard ImageNet statistics).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        print(f"β MAE Waste Classifier loaded on {self.device}")
        print(f"π Model: ViT-Base MAE, Classes: {len(self.class_names)}")

    def _load_model(self):
        """Build the ViT-Base model and load the finetuned weights.

        The checkpoint may be either a bare state dict or a dict holding the
        state dict under ``'model_state_dict'``.

        Returns:
            The model, moved to ``self.device`` and switched to eval mode.

        Raises:
            Exception: Any error from model creation or checkpoint loading is
                printed and re-raised unchanged.
        """
        try:
            # Create ViT model using timm; weights come from the checkpoint,
            # so no pretrained download is requested here.
            model = timm.create_model(
                'vit_base_patch16_224',
                pretrained=False,
                num_classes=len(self.class_names)
            )

            # SECURITY NOTE: torch.load unpickles the file, and the checkpoint
            # may have been downloaded from the Hub. Only load checkpoints from
            # repositories you trust (consider weights_only=True if the
            # checkpoint contents allow it).
            checkpoint = torch.load(self.model_path, map_location=self.device)

            if 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            else:
                state_dict = checkpoint
            model.load_state_dict(state_dict)

            model.to(self.device)
            model.eval()
            print(f"β Loaded finetuned MAE model from {self.model_path}")
            return model
        except Exception as e:
            print(f"β Error loading model: {e}")
            raise

    def classify_image(self, image, top_k=5):
        """Classify a waste image.

        Args:
            image: PIL Image, or a path (``str`` or ``os.PathLike``) to an
                image file. Images in any mode are converted to RGB.
            top_k: Number of top predictions to return; must be at least 1 and
                is capped at the number of classes.

        Returns:
            dict: On success, ``{'success': True, 'predicted_class': str,
            'confidence': float, 'top_predictions': [{'class', 'confidence'},
            ...]}``. On any failure, ``{'success': False, 'error': str}``;
            this method does not raise.
        """
        try:
            # Validate first so a bad top_k yields a clear message instead of
            # an IndexError on an empty prediction list.
            if top_k < 1:
                raise ValueError("top_k must be a positive integer")

            if isinstance(image, (str, os.PathLike)):
                # convert() forces the pixel data to load, so the file handle
                # can be closed as soon as the RGB copy exists.
                with Image.open(image) as opened:
                    image = opened.convert('RGB')
            elif isinstance(image, Image.Image):
                # RGBA / grayscale / palette images would otherwise reach
                # Normalize with the wrong number of channels.
                image = image.convert('RGB')
            else:
                raise ValueError("Image must be PIL Image or path string")

            # Preprocess and add the batch dimension.
            input_tensor = self.transform(image).unsqueeze(0).to(self.device)

            # Inference
            with torch.no_grad():
                outputs = self.model(input_tensor)
                probabilities = F.softmax(outputs, dim=1)

            # Top predictions, best first (torch.topk returns sorted values).
            k = min(top_k, len(self.class_names))
            top_probs, top_indices = torch.topk(probabilities, k=k)
            top_predictions = [
                {'class': self.class_names[idx], 'confidence': prob}
                for prob, idx in zip(top_probs[0].tolist(), top_indices[0].tolist())
            ]

            best_pred = top_predictions[0]
            return {
                'success': True,
                'predicted_class': best_pred['class'],
                'confidence': best_pred['confidence'],
                'top_predictions': top_predictions
            }
        except Exception as e:
            # Deliberate catch-all: callers get an error dict, not an exception.
            return {
                'success': False,
                'error': str(e)
            }

    def get_disposal_instructions(self, class_name):
        """Return the disposal instructions for ``class_name``.

        Unknown class names yield a generic "no instructions" message.
        """
        return self.disposal_instructions.get(class_name, "No specific instructions available.")

    def get_model_info(self):
        """Return a dict describing the loaded model, device and checkpoint path."""
        return {
            'model_name': 'ViT-Base MAE',
            'architecture': 'Vision Transformer (ViT-Base)',
            'pretrained': 'MAE (Masked Autoencoder)',
            'num_classes': len(self.class_names),
            'device': self.device,
            'model_path': self.model_path
        }
# Smoke test: build the classifier, classify the first sample image that
# exists on disk (if any), then print the model metadata.
if __name__ == "__main__":
    print("π§ͺ Testing MAE Waste Classifier...")
    try:
        classifier = MAEWasteClassifier()

        # Candidate sample images, checked in order; only the first existing
        # one is classified.
        candidate_paths = (
            "fail_images/image.webp",
            "fail_images/IMG_9501.webp",
        )
        sample_path = next(
            (path for path in candidate_paths if os.path.exists(path)),
            None,
        )

        if sample_path is None:
            print("βΉοΈ No test images found, but classifier loaded successfully!")
        else:
            print(f"\nπ Testing with {sample_path}")
            outcome = classifier.classify_image(sample_path)
            if not outcome['success']:
                print(f"β Error: {outcome['error']}")
            else:
                print(f"β Predicted: {outcome['predicted_class']} ({outcome['confidence']:.3f})")
                print(f"π Instructions: {classifier.get_disposal_instructions(outcome['predicted_class'])}")
                print("\nπ Top predictions:")
                # Show at most the three highest-ranked predictions.
                leaders = outcome['top_predictions'][:3]
                for rank, entry in enumerate(leaders, start=1):
                    print(f" {rank}. {entry['class']}: {entry['confidence']:.3f}")

        # Print model info
        print(f"\nπ€ Model Info:")
        for field, setting in classifier.get_model_info().items():
            print(f" {field}: {setting}")
        print("\nSuccess!")
    except Exception as e:
        # Top-level boundary of the smoke test: report the failure with a
        # full traceback instead of letting the script die silently.
        print(f"β Error: {e}")
        import traceback
        traceback.print_exc()