File size: 4,417 Bytes
71c32d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124


import os
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from config import CANDIDATE_LABELS
import torch
from torchvision.models import resnet18, ResNet18_Weights
import torch.nn as nn

class SimpleRockClassifier:
    """Heuristic rock classifier.

    A label is chosen from filename keywords first. If no keyword matches, it
    falls back to a mean-colour heuristic. A ResNet-18 backbone (final layer
    removed) is also run to produce an embedding, which is returned alongside
    the label but is not used for the classification itself.
    """

    # Length of the flattened ResNet-18 embedding; also the fallback size.
    _FEATURE_DIM = 512

    def __init__(self):
        # Resize to the backbone's input size and normalise with the standard
        # ImageNet channel statistics used by torchvision's pretrained models.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        # Pretrained ResNet-18 (torchvision may download the weights on first
        # use). Dropping the last child (the fc layer) leaves the pooled
        # embedding as the model output.
        weights = ResNet18_Weights.DEFAULT
        self.model = resnet18(weights=weights)
        self.model = nn.Sequential(*list(self.model.children())[:-1])
        self.model.eval()

        # Substring -> label rules applied to the lower-cased file basename.
        # Dict order matters: the first matching keyword wins.
        self.keyword_mapping = {
            'gold': 'gold-bearing rock',
            'iron': 'iron-rich rock',
            'pyrite': 'iron-rich rock',
            'lithium': 'lithium-rich rock',
            'spodumene': 'lithium-rich rock',
            'copper': 'copper-bearing rock',
            'quartz': 'quartz-rich rock',
            'silica': 'quartz-rich rock',
            'crystal': 'quartz-rich rock',
            'waste': 'waste rock',
            'granite': 'waste rock',
            'basalt': 'waste rock'
        }

    def extract_features(self, image_path):
        """Return the backbone embedding of the image as a (1, 512) array.

        On any failure (unreadable file, decode or model error), the error is
        printed. A deterministic all-zero float32 array of the same shape is
        then returned, so callers can detect the fallback instead of receiving
        random values.
        """
        try:
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            image_tensor = self.transform(image).unsqueeze(0)

            with torch.no_grad():
                features = self.model(image_tensor)
                features = features.view(features.size(0), -1)

            return features.numpy()
        except Exception as e:
            print(f"Error extracting features: {e}")
            return np.zeros((1, self._FEATURE_DIM), dtype=np.float32)

    def _match_filename_keyword(self, image_path):
        """Return ``(keyword, rock_type)`` for the first keyword found in the
        lower-cased basename of *image_path*, or ``None`` if none matches."""
        filename = os.path.basename(image_path).lower()
        return next(
            ((kw, rock) for kw, rock in self.keyword_mapping.items()
             if kw in filename),
            None,
        )

    def classify_by_filename(self, image_path):
        """Classify by filename keyword, falling back to colour analysis.

        Returns a ``(rock_type, confidence)`` tuple. A keyword match always
        has confidence 0.8.
        """
        match = self._match_filename_keyword(image_path)
        if match is not None:
            return match[1], 0.8

        # No keyword in the filename: use the colour heuristic instead.
        return self.analyze_colors(image_path)

    @staticmethod
    def _classify_mean_color(r, g, b):
        """Map an average RGB colour (0-255 per channel) to
        ``(rock_type, confidence)`` using fixed thresholds.

        Rules are checked in order, and the first match wins.
        """
        brightness = (r + g + b) / 3

        # Gold: bright yellow (strong red, fairly strong green, little blue).
        if r > 180 and g > 150 and b < 100 and r > g > b:
            return "gold-bearing rock", 0.7

        # Iron: dark overall.
        if brightness < 100:
            return "iron-rich rock", 0.65

        # Copper: green-dominant.
        if g > r and g > b and brightness > 80:
            return "copper-bearing rock", 0.6

        # Light minerals (lithium/quartz).
        if brightness > 200:
            # Purple tint: red and blue similar AND both above green. Without
            # the green test, plain white/grey would be labelled lithium.
            if brightness > 220 and abs(r - b) < 30 and g < min(r, b):
                return "lithium-rich rock", 0.55
            return "quartz-rich rock", 0.7

        return "waste rock", 0.5

    def analyze_colors(self, image_path):
        """Classify the image by its average colour.

        Returns ``(rock_type, confidence)``. On any error, the error is
        printed and ``("waste rock", 0.3)`` is returned.
        """
        try:
            with Image.open(image_path) as img:
                # Downsample first; only the mean colour is needed.
                image_small = img.convert("RGB").resize((50, 50))
            pixels = np.asarray(image_small, dtype=np.float64)

            r, g, b = pixels.mean(axis=(0, 1))
            return self._classify_mean_color(r, g, b)

        except Exception as e:
            print(f"Error in color analysis: {e}")
            return "waste rock", 0.3

    def predict(self, image_path):
        """Classify *image_path* and return a result dict.

        Keys: ``rock_type``, ``confidence``, ``features`` (the embedding from
        :meth:`extract_features`) and ``explanation``. The explanation states
        whether the label came from a filename keyword or from colour analysis.
        """
        match = self._match_filename_keyword(image_path)
        rock_type, confidence = self.classify_by_filename(image_path)

        # Embedding is returned for downstream use; it does not affect the label.
        features = self.extract_features(image_path)

        if match is not None:
            explanation = (
                f"Classified as {rock_type} based on filename keyword '{match[0]}'"
            )
        else:
            explanation = f"Classified as {rock_type} based on visual characteristics"

        return {
            "rock_type": rock_type,
            "confidence": confidence,
            "features": features,
            "explanation": explanation,
        }