Spaces:
Runtime error
Runtime error
Upload mae_waste_classifier.py with huggingface_hub
Browse files- mae_waste_classifier.py +209 -0
mae_waste_classifier.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""MAE ViT-Base waste classifier for inference."""
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torchvision import transforms
|
7 |
+
from PIL import Image
|
8 |
+
import timm
|
9 |
+
import os
|
10 |
+
import json
|
11 |
+
from huggingface_hub import hf_hub_download
|
12 |
+
|
13 |
+
class MAEWasteClassifier:
    """Waste classifier using a finetuned MAE ViT-Base model.

    The checkpoint is resolved in this order:

    1. ``model_path`` if it is given and the file exists,
    2. ``best_model.pth`` downloaded from the Hugging Face Hub repository
       ``hf_model_id`` (cached under ``./hf_cache``),
    3. the local fallback ``output_simple_mae/best_model.pth``.
    """

    def __init__(self, model_path=None, hf_model_id="ysfad/mae-waste-classifier", device=None):
        """Resolve the checkpoint, load the model and build the preprocessing.

        Args:
            model_path: Optional path to a local checkpoint. Ignored when the
                file does not exist.
            hf_model_id: Hugging Face Hub repository to download
                ``best_model.pth`` from when no usable ``model_path`` is given.
            device: Device string passed to torch. Defaults to ``'cuda'`` when
                available, otherwise ``'cpu'``.

        Raises:
            FileNotFoundError: If no local checkpoint exists and the download
                from the Hub failed.
        """
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.hf_model_id = hf_model_id

        # Try to load model from different sources
        if model_path and os.path.exists(model_path):
            self.model_path = model_path
            print(f"📁 Using local model: {model_path}")
        else:
            # Try to download from HF Hub
            try:
                print(f"🔄 Downloading model from HF Hub: {hf_model_id}")
                self.model_path = hf_hub_download(
                    repo_id=hf_model_id,
                    filename="best_model.pth",
                    cache_dir="./hf_cache",
                )
                print(f"✅ Downloaded model to: {self.model_path}")
            except Exception as e:
                print(f"⚠️ Could not download from HF Hub: {e}")
                # Fallback to local path
                self.model_path = "output_simple_mae/best_model.pth"
                if not os.path.exists(self.model_path):
                    # Chain the download failure so the root cause stays visible.
                    raise FileNotFoundError(
                        f"Model not found locally at {self.model_path} and could not download from HF Hub"
                    ) from e

        # Class names from training; the list index is used as the class index
        # when decoding model outputs in classify_image.
        self.class_names = [
            'Cardboard', 'Food Organics', 'Glass', 'Metal',
            'Miscellaneous Trash', 'Paper', 'Plastic',
            'Textile Trash', 'Vegetation',
        ]

        # Disposal instructions keyed by class name
        self.disposal_instructions = {
            "Cardboard": "Flatten and place in recycling bin. Remove any tape or staples.",
            "Food Organics": "Compost in organic waste bin or home composter.",
            "Glass": "Rinse and place in glass recycling. Remove lids and caps.",
            "Metal": "Rinse aluminum/steel cans and place in recycling bin.",
            "Miscellaneous Trash": "Dispose in general waste bin. Cannot be recycled.",
            "Paper": "Place clean paper in recycling. Remove plastic windows from envelopes.",
            "Plastic": "Check recycling number. Rinse containers before recycling.",
            "Textile Trash": "Donate if reusable, otherwise dispose in textile recycling.",
            "Vegetation": "Compost in organic waste or use for mulch in garden.",
        }

        # Load model
        self.model = self._load_model()

        # Image preprocessing: fixed 224x224 input, then per-channel
        # normalisation (the standard ImageNet mean/std values).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        print(f"✅ MAE Waste Classifier loaded on {self.device}")
        print(f"📊 Model: ViT-Base MAE, Classes: {len(self.class_names)}")

    def _load_model(self):
        """Load the finetuned MAE model.

        Builds a ``vit_base_patch16_224`` with one output per class name,
        loads the weights from ``self.model_path`` (either a bare state dict
        or a dict holding it under ``'model_state_dict'``), moves the model to
        ``self.device`` and switches it to eval mode.

        Returns:
            The loaded model.

        Raises:
            Exception: Any error raised while building or loading the model is
                printed and re-raised unchanged.
        """
        try:
            # Create ViT model using timm
            model = timm.create_model(
                'vit_base_patch16_224',
                pretrained=False,
                num_classes=len(self.class_names),
            )

            # Load checkpoint
            # NOTE(review): torch.load unpickles the file, and the checkpoint
            # may have been downloaded from the Hub. If it only contains
            # tensors and plain containers, consider weights_only=True.
            checkpoint = torch.load(self.model_path, map_location=self.device)

            # Load state dict
            if 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
            else:
                model.load_state_dict(checkpoint)

            model.to(self.device)
            model.eval()

            print(f"✅ Loaded finetuned MAE model from {self.model_path}")
            return model

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise

    def classify_image(self, image, top_k=5):
        """Classify a waste image.

        Args:
            image: PIL Image, or a path (``str`` or ``os.PathLike``) to an
                image file. Images in any mode are converted to RGB.
            top_k: Number of top predictions to return; capped at the number
                of classes. Must be at least 1.

        Returns:
            dict: On success ``{'success': True, 'predicted_class': str,
            'confidence': float, 'top_predictions': [{'class', 'confidence'},
            ...]}`` ordered by descending confidence. On any failure
            ``{'success': False, 'error': str}``; errors are never raised.
        """
        try:
            if top_k < 1:
                raise ValueError("top_k must be at least 1")

            # Load and preprocess image
            if isinstance(image, (str, os.PathLike)):
                # The context manager closes the file handle once the pixel
                # data has been copied by convert().
                with Image.open(image) as opened:
                    image = opened.convert('RGB')
            elif isinstance(image, Image.Image):
                # Normalize expects exactly three channels, so RGBA, palette
                # and grayscale inputs have to be converted first.
                if image.mode != 'RGB':
                    image = image.convert('RGB')
            else:
                raise ValueError("Image must be PIL Image or path string")

            # Preprocess: add a batch dimension of 1
            input_tensor = self.transform(image).unsqueeze(0).to(self.device)

            # Inference
            with torch.no_grad():
                outputs = self.model(input_tensor)
                probabilities = F.softmax(outputs, dim=1)

            # Get top predictions
            top_probs, top_indices = torch.topk(
                probabilities, k=min(top_k, len(self.class_names))
            )

            top_predictions = [
                {'class': self.class_names[idx.item()], 'confidence': prob.item()}
                for prob, idx in zip(top_probs[0], top_indices[0])
            ]

            # Best prediction (topk returns values in descending order)
            best_pred = top_predictions[0]

            return {
                'success': True,
                'predicted_class': best_pred['class'],
                'confidence': best_pred['confidence'],
                'top_predictions': top_predictions,
            }

        except Exception as e:
            return {
                'success': False,
                'error': str(e),
            }

    def get_disposal_instructions(self, class_name):
        """Get disposal instructions for a waste class.

        Returns a generic fallback sentence for unknown class names.
        """
        return self.disposal_instructions.get(class_name, "No specific instructions available.")

    def get_model_info(self):
        """Get information about the loaded model as a plain dict."""
        return {
            'model_name': 'ViT-Base MAE',
            'architecture': 'Vision Transformer (ViT-Base)',
            'pretrained': 'MAE (Masked Autoencoder)',
            'num_classes': len(self.class_names),
            'device': self.device,
            'model_path': self.model_path,
        }
|
165 |
+
|
166 |
+
# Test the classifier: smoke test that loads the model, classifies the first
# sample image that exists on disk, and prints the model info.
if __name__ == "__main__":
    print("🧪 Testing MAE Waste Classifier...")

    try:
        # Initialize classifier
        classifier = MAEWasteClassifier()

        # Test with a sample image if available
        test_images = [
            "fail_images/image.webp",
            "fail_images/IMG_9501.webp",
        ]

        for img_path in test_images:
            if os.path.exists(img_path):
                print(f"\n🔍 Testing with {img_path}")
                result = classifier.classify_image(img_path)

                if result['success']:
                    print(f"✅ Predicted: {result['predicted_class']} ({result['confidence']:.3f})")
                    print(f"📋 Instructions: {classifier.get_disposal_instructions(result['predicted_class'])}")

                    print("\n📊 Top predictions:")
                    for i, pred in enumerate(result['top_predictions'][:3], 1):
                        print(f" {i}. {pred['class']}: {pred['confidence']:.3f}")
                else:
                    print(f"❌ Error: {result['error']}")
                # Only the first image found is tested.
                break
        else:
            # for/else: reached only when the loop never hit `break`,
            # i.e. none of the sample images exist.
            print("ℹ️ No test images found, but classifier loaded successfully!")

        # Print model info
        info = classifier.get_model_info()
        print("\n🤖 Model Info:")
        for key, value in info.items():
            print(f" {key}: {value}")

        print("\nSuccess!")

    except Exception as e:
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()
        # Report the failure through the exit status as well, so callers of
        # the script do not see a failed smoke test as success.
        raise SystemExit(1)
|