#!/usr/bin/env python3
"""MAE ViT-Base waste classifier for inference."""
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import timm
import os
import json
from huggingface_hub import hf_hub_download
class MAEWasteClassifier:
    """Waste classifier using finetuned MAE ViT-Base model."""

    def __init__(self, model_path=None, hf_model_id="ysfad/mae-waste-classifier", device=None):
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.hf_model_id = hf_model_id

        # Try to load model from different sources
        if model_path and os.path.exists(model_path):
            self.model_path = model_path
            print(f"Using local model: {model_path}")
        else:
            # Try to download from HF Hub
            try:
                print(f"Downloading model from HF Hub: {hf_model_id}")
                self.model_path = hf_hub_download(
                    repo_id=hf_model_id,
                    filename="best_model.pth",
                    cache_dir="./hf_cache"
                )
                print(f"Downloaded model to: {self.model_path}")
            except Exception as e:
                print(f"Warning: could not download from HF Hub: {e}")
                # Fallback to local path
                self.model_path = "output_simple_mae/best_model.pth"
                if not os.path.exists(self.model_path):
                    raise FileNotFoundError(
                        f"Model not found locally at {self.model_path} and could not download from HF Hub"
                    )

        # Class names from training
        self.class_names = [
            'Cardboard', 'Food Organics', 'Glass', 'Metal',
            'Miscellaneous Trash', 'Paper', 'Plastic',
            'Textile Trash', 'Vegetation'
        ]
        # Load disposal instructions
        self.disposal_instructions = {
            "Cardboard": "Flatten and place in recycling bin. Remove any tape or staples.",
            "Food Organics": "Compost in organic waste bin or home composter.",
            "Glass": "Rinse and place in glass recycling. Remove lids and caps.",
            "Metal": "Rinse aluminum/steel cans and place in recycling bin.",
            "Miscellaneous Trash": "Dispose in general waste bin. Cannot be recycled.",
            "Paper": "Place clean paper in recycling. Remove plastic windows from envelopes.",
            "Plastic": "Check recycling number. Rinse containers before recycling.",
            "Textile Trash": "Donate if reusable, otherwise dispose in textile recycling.",
            "Vegetation": "Compost in organic waste or use for mulch in garden."
        }

        # Load model
        self.model = self._load_model()

        # Image preprocessing (224x224 resize, standard ImageNet normalization)
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        print(f"MAE Waste Classifier loaded on {self.device}")
        print(f"Model: ViT-Base MAE, Classes: {len(self.class_names)}")
    def _load_model(self):
        """Load the finetuned MAE model."""
        try:
            # Create ViT model using timm
            model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=len(self.class_names))

            # Load checkpoint
            checkpoint = torch.load(self.model_path, map_location=self.device)

            # Load state dict (checkpoint may be a raw state dict or a training checkpoint)
            if 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
            else:
                model.load_state_dict(checkpoint)

            model.to(self.device)
            model.eval()

            print(f"Loaded finetuned MAE model from {self.model_path}")
            return model
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
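
    # Note on the checkpoint layout: _load_model accepts either a raw state_dict or a
    # dict containing a 'model_state_dict' key. For reference, a compatible checkpoint
    # could be produced from a training script roughly like
    #   torch.save({'model_state_dict': model.state_dict()}, 'best_model.pth')
    # (a sketch only; the actual training code is not part of this file).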
    def classify_image(self, image, top_k=5):
        """
        Classify a waste image.

        Args:
            image: PIL Image or path to image
            top_k: Number of top predictions to return

        Returns:
            dict: Classification results
        """
        try:
            # Load and preprocess image
            if isinstance(image, str):
                image = Image.open(image).convert('RGB')
            elif not isinstance(image, Image.Image):
                raise ValueError("Image must be PIL Image or path string")

            # Preprocess
            input_tensor = self.transform(image).unsqueeze(0).to(self.device)

            # Inference
            with torch.no_grad():
                outputs = self.model(input_tensor)
                probabilities = F.softmax(outputs, dim=1)

            # Get top predictions
            top_probs, top_indices = torch.topk(probabilities, k=min(top_k, len(self.class_names)))

            top_predictions = []
            for prob, idx in zip(top_probs[0], top_indices[0]):
                top_predictions.append({
                    'class': self.class_names[idx.item()],
                    'confidence': prob.item()
                })

            # Best prediction
            best_pred = top_predictions[0]

            return {
                'success': True,
                'predicted_class': best_pred['class'],
                'confidence': best_pred['confidence'],
                'top_predictions': top_predictions
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }
    def get_disposal_instructions(self, class_name):
        """Get disposal instructions for a waste class."""
        return self.disposal_instructions.get(class_name, "No specific instructions available.")

    def get_model_info(self):
        """Get information about the loaded model."""
        return {
            'model_name': 'ViT-Base MAE',
            'architecture': 'Vision Transformer (ViT-Base)',
            'pretrained': 'MAE (Masked Autoencoder)',
            'num_classes': len(self.class_names),
            'device': self.device,
            'model_path': self.model_path
        }
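
# Illustrative usage sketch for calling the classifier from other code. It is not
# invoked anywhere in this module; "example.jpg" is a placeholder path, not a file
# that ships with this repo.
def _example_usage(image_path="example.jpg"):
    classifier = MAEWasteClassifier()
    # classify_image accepts either a path string or an in-memory PIL image
    image = Image.open(image_path).convert('RGB')
    result = classifier.classify_image(image, top_k=3)
    if result['success']:
        label = result['predicted_class']
        print(f"{label} ({result['confidence']:.3f})")
        print(classifier.get_disposal_instructions(label))
    return result
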
# Test the classifier
if __name__ == "__main__":
    print("Testing MAE Waste Classifier...")

    try:
        # Initialize classifier
        classifier = MAEWasteClassifier()

        # Test with a sample image if available
        test_images = [
            "fail_images/image.webp",
            "fail_images/IMG_9501.webp"
        ]

        for img_path in test_images:
            if os.path.exists(img_path):
                print(f"\nTesting with {img_path}")
                result = classifier.classify_image(img_path)

                if result['success']:
                    print(f"Predicted: {result['predicted_class']} ({result['confidence']:.3f})")
                    print(f"Instructions: {classifier.get_disposal_instructions(result['predicted_class'])}")

                    print("\nTop predictions:")
                    for i, pred in enumerate(result['top_predictions'][:3], 1):
                        print(f"  {i}. {pred['class']}: {pred['confidence']:.3f}")
                else:
                    print(f"Error: {result['error']}")
                break
        else:
            print("No test images found, but classifier loaded successfully!")

        # Print model info
        info = classifier.get_model_info()
        print("\nModel Info:")
        for key, value in info.items():
            print(f"  {key}: {value}")

        print("\nSuccess!")

    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()