# Geologist_AI / simple_classifier.py
# Uploaded by solfedge ("Upload 9 files", commit 71c32d5, verified).
import os
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from config import CANDIDATE_LABELS
import torch
from torchvision.models import resnet18, ResNet18_Weights
import torch.nn as nn
class SimpleRockClassifier:
    """Heuristic rock classifier for images.

    The label is chosen by simple rules. A keyword found in the image's
    filename wins. Otherwise the image's average colour is compared against
    a few hand-picked thresholds.

    A pretrained ResNet-18 backbone (with its final child layer removed) is
    also run. It produces a feature vector that ``predict`` returns alongside
    the label but does not use for the decision.
    """

    # Confidence reported when a filename keyword decides the label.
    _FILENAME_CONFIDENCE = 0.8
    # Width of the fallback feature vector returned when extraction fails.
    # NOTE(review): presumably equals the ResNet-18 pooled feature size;
    # confirm if the backbone changes.
    _FEATURE_DIM = 512

    def __init__(self):
        # Preprocessing for the backbone: resize to 224x224, convert to a
        # tensor, then normalise each channel with these mean/std values.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        # Pretrained ResNet-18 with its last child module dropped, so the
        # forward pass returns intermediate features instead of the original
        # final-layer output.
        weights = ResNet18_Weights.DEFAULT
        backbone = resnet18(weights=weights)
        self.model = nn.Sequential(*list(backbone.children())[:-1])
        self.model.eval()
        # Filename keyword -> rock type. Keywords are tried in insertion
        # order and the first one found in the filename wins.
        self.keyword_mapping = {
            'gold': 'gold-bearing rock',
            'iron': 'iron-rich rock',
            'pyrite': 'iron-rich rock',
            'lithium': 'lithium-rich rock',
            'spodumene': 'lithium-rich rock',
            'copper': 'copper-bearing rock',
            'quartz': 'quartz-rich rock',
            'silica': 'quartz-rich rock',
            'crystal': 'quartz-rich rock',
            'waste': 'waste rock',
            'granite': 'waste rock',
            'basalt': 'waste rock',
        }

    def extract_features(self, image_path):
        """Return the backbone feature vector for the image at ``image_path``.

        The backbone output is flattened to shape ``(1, D)`` (one image per
        batch) and returned as a NumPy array.

        If anything fails (unreadable file, model error, ...), the error is
        printed and a deterministic zero vector of shape
        ``(1, _FEATURE_DIM)`` is returned. Callers therefore always receive
        an array.
        """
        try:
            # Close the file handle promptly; convert() yields an
            # independent in-memory image.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            image_tensor = self.transform(image).unsqueeze(0)
            with torch.no_grad():
                features = self.model(image_tensor)
            features = features.view(features.size(0), -1)
            return features.numpy()
        except Exception as e:
            print(f"Error extracting features: {e}")
            # Zeros rather than random values: reproducible, and clearly not
            # a real embedding.
            return np.zeros((1, self._FEATURE_DIM), dtype=np.float32)

    def _match_filename_keyword(self, image_path):
        """Return the first ``keyword_mapping`` key in the filename, or None.

        Matching is a plain substring test on the lower-cased basename.
        NOTE: this means e.g. "environment.jpg" matches 'iron'.
        """
        filename = os.path.basename(image_path).lower()
        return next((kw for kw in self.keyword_mapping if kw in filename), None)

    def classify_by_filename(self, image_path):
        """Classify by filename keyword, falling back to colour analysis.

        Returns:
            ``(rock_type, confidence)``. When a keyword matches, the
            confidence is ``_FILENAME_CONFIDENCE``. Otherwise the result of
            ``analyze_colors`` is returned.
        """
        keyword = self._match_filename_keyword(image_path)
        if keyword is not None:
            return self.keyword_mapping[keyword], self._FILENAME_CONFIDENCE
        return self.analyze_colors(image_path)

    @staticmethod
    def _classify_mean_color(r, g, b):
        """Map an average RGB colour (0-255 per channel) to ``(label, conf)``.

        Rules are checked in order and the first match wins.
        """
        brightness = (r + g + b) / 3
        # Gold: yellowish, i.e. high red and green, low blue, red > green > blue.
        if r > 180 and g > 150 and b < 100 and r > g > b:
            return "gold-bearing rock", 0.7
        # Iron: dark overall.
        if brightness < 100:
            return "iron-rich rock", 0.65
        # Copper: green is the dominant channel. Brightness is >= 100 here,
        # so no separate lower brightness bound is needed.
        if g > r and g > b:
            return "copper-bearing rock", 0.6
        # Light minerals (lithium / quartz).
        if brightness > 200:
            # NOTE(review): originally described as a "purple tint" check, but
            # it only requires red and blue to be close on a very bright
            # image; green is not compared. Confirm the intended rule.
            if abs(r - b) < 30 and brightness > 220:
                return "lithium-rich rock", 0.55
            return "quartz-rich rock", 0.7
        return "waste rock", 0.5

    def analyze_colors(self, image_path):
        """Classify the image by its average colour.

        The image is downsampled to 50x50 for speed, averaged per channel,
        and passed through ``_classify_mean_color``. Any error is printed,
        and ``("waste rock", 0.3)`` is returned.
        """
        try:
            with Image.open(image_path) as img:
                image_small = img.convert("RGB").resize((50, 50))
            pixels = np.array(image_small)
            r, g, b = np.mean(pixels, axis=(0, 1))
            return self._classify_mean_color(r, g, b)
        except Exception as e:
            print(f"Error in color analysis: {e}")
            return "waste rock", 0.3

    def predict(self, image_path):
        """Classify ``image_path`` and return a result dict.

        Keys:
            rock_type: predicted label.
            confidence: heuristic confidence for the label.
            features: backbone features from ``extract_features``
                (returned for potential future use, not used for the label).
            explanation: human-readable reason, naming the matched filename
                keyword when one decided the label.
        """
        rock_type, confidence = self.classify_by_filename(image_path)
        keyword = self._match_filename_keyword(image_path)
        if keyword is not None:
            basis = f"filename keyword '{keyword}'"
        else:
            basis = "visual characteristics"
        features = self.extract_features(image_path)
        return {
            "rock_type": rock_type,
            "confidence": confidence,
            "features": features,
            "explanation": f"Classified as {rock_type} based on {basis}",
        }