Commit 71c32d5 (verified) · committed by solfedge · 1 Parent(s): f90b73e

Upload 9 files

app.py ADDED
@@ -0,0 +1,77 @@
import gradio as gr
import os
from simple_classifier import SimpleRockClassifier

# Initialize classifier (no training needed)
classifier = SimpleRockClassifier()

def analyze_core(image):
    """Analyze a drill core image"""
    # Guard against clicking Analyze with no image uploaded
    if image is None:
        return "## Error\nPlease upload an image before analyzing."

    # Save uploaded image temporarily
    temp_path = "temp_upload.jpg"
    image.save(temp_path)

    # Get prediction
    try:
        result = classifier.predict(temp_path)

        # Format response
        response = f"""
## Drill Core Analysis Results

### Primary Prediction
**Rock Type:** `{result['rock_type']}`
**Confidence:** `{result['confidence']:.2f}`

### Analysis Details
{result['explanation']}
"""
    except Exception as e:
        response = f"## Error\nAn error occurred during analysis: {e}"

    # Clean up
    if os.path.exists(temp_path):
        os.remove(temp_path)

    return response

# Create Gradio interface
with gr.Blocks(title="Geologist_AI - Core Logger") as demo:
    gr.Markdown("# Geologist_AI - Core Logger")
    gr.Markdown("Upload a drill core image to identify the rock type")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="📷 Drill Core Image")
            submit_btn = gr.Button("🔍 Analyze Core Sample", variant="primary")
        with gr.Column():
            output_text = gr.Markdown(label="📊 Analysis Results")

    submit_btn.click(
        fn=analyze_core,
        inputs=image_input,
        outputs=output_text
    )

    gr.Markdown("---")
    gr.Markdown("### About this Tool")
    gr.Markdown("""
This AI-powered geologist identifies rock types based on:
- **Visual color analysis**
- **Deep learning feature extraction**

**Supported rock types:**
- Gold-bearing rock
- Iron-rich rock
- Lithium-rich rock
- Copper-bearing rock
- Quartz-rich rock
- Waste rock
""")

# Launch the app
if __name__ == "__main__":
    demo.launch()
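Running python app.py starts the Gradio server locally (Gradio's default address is http://127.0.0.1:7860); on Hugging Face Spaces the same demo.launch() entry point is used.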
cluster_analyzer.py ADDED
@@ -0,0 +1,154 @@
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from config import NUM_CLUSTERS, OUTPUT_DIR
import os

class ClusterAnalyzer:
    def __init__(self, n_clusters=NUM_CLUSTERS):
        self.n_clusters = n_clusters
        self.scaler = StandardScaler()
        self.kmeans = None
        self.pca = None

    def fit_predict(self, features):
        """Fit KMeans and return cluster labels"""
        # Standardize features
        features_scaled = self.scaler.fit_transform(features)

        # Adaptive PCA - use min(n_samples - 1, n_features, 50) components
        n_components = min(features_scaled.shape[0] - 1, features_scaled.shape[1], 50)
        if n_components < 1:
            n_components = 1

        print(f"Using {n_components} PCA components (adapted to data size)")
        self.pca = PCA(n_components=n_components)
        features_reduced = self.pca.fit_transform(features_scaled)

        # Adjust number of clusters if needed
        n_clusters = min(self.n_clusters, len(features_reduced))
        if n_clusters < 1:
            n_clusters = 1

        print(f"Using {n_clusters} clusters")
        self.kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        labels = self.kmeans.fit_predict(features_reduced)
        return labels, features_reduced

    def get_cluster_centers(self):
        """Return cluster centers"""
        if self.kmeans is not None:
            return self.kmeans.cluster_centers_
        return None

    def visualize_clusters(self, features, labels, image_paths, save_path=None):
        """Visualize clusters using PCA"""
        # Further reduce to 2D for visualization (if possible)
        if features.shape[0] > 2 and features.shape[1] > 2:
            pca_2d = PCA(n_components=min(2, features.shape[0] - 1, features.shape[1]))
            features_2d = pca_2d.fit_transform(features)
        else:
            # If we can't do PCA, use the first 2 features (zero-padded if needed)
            if features.shape[1] >= 2:
                features_2d = features[:, :2]
            else:
                features_2d = np.hstack([features, np.zeros((features.shape[0], 2 - features.shape[1]))])

        # Create plot
        plt.figure(figsize=(12, 8))

        # Handle case where we have only one cluster; set the title per branch
        # so the single-cluster title is not overwritten
        unique_labels = np.unique(labels)
        if len(unique_labels) > 1:
            scatter = plt.scatter(features_2d[:, 0], features_2d[:, 1], c=labels, cmap='tab10', alpha=0.7, s=100)
            plt.colorbar(scatter)
            plt.title('Drill Core Sample Clusters (PCA Visualization)', fontsize=16)
        else:
            plt.scatter(features_2d[:, 0], features_2d[:, 1], c='blue', alpha=0.7, s=100)
            plt.title(f'All samples in single cluster (Cluster {labels[0]})', fontsize=16)

        plt.xlabel('Feature Dimension 1')
        plt.ylabel('Feature Dimension 2')

        # Annotate some points
        for i in range(min(15, len(features_2d))):
            if i < len(image_paths):
                filename = os.path.basename(image_paths[i])[:15] + "..."
                plt.annotate(filename, (features_2d[i, 0], features_2d[i, 1]),
                             xytext=(5, 5), textcoords='offset points', fontsize=8, alpha=0.7)

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Cluster visualization saved to {save_path}")
        plt.show()

    def create_cluster_map(self, image_paths, labels):
        """Create mapping from cluster ID to image paths"""
        cluster_map = {}
        for path, label in zip(image_paths, labels):
            if label not in cluster_map:
                cluster_map[label] = []
            cluster_map[label].append(path)
        return cluster_map

    def analyze_cluster_characteristics(self, features, labels, image_paths):
        """Analyze characteristics of each cluster"""
        cluster_stats = {}

        # Get features for each cluster
        for cluster_id in np.unique(labels):
            mask = labels == cluster_id
            cluster_features = features[mask]

            # Calculate statistics
            mean_features = np.mean(cluster_features, axis=0)
            std_features = np.std(cluster_features, axis=0)

            # Get image paths for this cluster
            cluster_images = [path for i, path in enumerate(image_paths) if labels[i] == cluster_id]

            cluster_stats[cluster_id] = {
                'count': len(cluster_images),
                'mean_features': mean_features,
                'std_features': std_features,
                'sample_images': cluster_images[:5]  # First 5 samples
            }

        return cluster_stats

    def analyze_clusters(self, features, image_paths):
        """Complete clustering analysis"""
        print(f"Performing clustering analysis on {len(image_paths)} samples...")
        print(f"Feature dimensions: {features.shape}")

        # Perform clustering
        labels, features_reduced = self.fit_predict(features)

        # Create cluster map
        cluster_map = self.create_cluster_map(image_paths, labels)

        # Analyze cluster characteristics
        cluster_stats = self.analyze_cluster_characteristics(features, labels, image_paths)

        # Visualize if we have enough samples
        if len(image_paths) > 2:
            viz_path = os.path.join(OUTPUT_DIR, "clusters.png")
            self.visualize_clusters(features, labels, image_paths, viz_path)

        # Print cluster information
        print("\n" + "=" * 60)
        print("CLUSTER ANALYSIS RESULTS")
        print("=" * 60)
        for cluster_id, stats in cluster_stats.items():
            print(f"\nCluster {cluster_id}:")
            print(f"  Samples: {stats['count']} images")
            print("  Sample files:")
            for path in stats['sample_images']:
                print(f"    - {os.path.basename(path)}")

        return labels, cluster_map, cluster_stats

if __name__ == "__main__":
    pass
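For context, a minimal sketch of how this analyzer is meant to be driven, assuming the FeatureExtractor and IMAGE_DIR defined in feature_extractor.py and config.py below; this runner script is not itself part of the commit:

# Hypothetical runner: extract ResNet features, then cluster them
from feature_extractor import FeatureExtractor
from cluster_analyzer import ClusterAnalyzer
from config import IMAGE_DIR

extractor = FeatureExtractor()
features, image_paths = extractor.extract_features(IMAGE_DIR)

analyzer = ClusterAnalyzer()
labels, cluster_map, cluster_stats = analyzer.analyze_clusters(features, image_paths)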
config.py ADDED
@@ -0,0 +1,48 @@
import os

# Directories
DATA_DIR = "data"
IMAGE_DIR = os.path.join(DATA_DIR, "core_images")
MODEL_DIR = "models"
OUTPUT_DIR = "output"

# Create directories if they don't exist
os.makedirs(IMAGE_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Model parameters - adaptive to data size
NUM_CLUSTERS = 3  # Reduced default
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 32

# Candidate labels for classification
CANDIDATE_LABELS = [
    "gold-bearing rock",
    "iron-rich rock",
    "lithium-rich rock",
    "copper-bearing rock",
    "waste rock",
    "quartz-rich rock",
    "sulfide-rich rock"
]

# Public geology repositories
DATASET_SOURCES = [
    {
        "name": "Geoscience Australia",
        "url": "https://geology.csiro.au/datasets/drill-core-images",
        "description": "Australian geological survey drill core images"
    },
    {
        "name": "USGS Mineral Resources",
        "url": "https://mrdata.usgs.gov/geology/state/map-viewer.php",
        "description": "US Geological Survey mineral resources data"
    },
    {
        "name": "BGS OpenGeoscience",
        "url": "https://www.bgs.ac.uk/discovering-geology/rock-library/",
        "description": "British Geological Survey rock sample images"
    }
]
core_dataset.py ADDED
@@ -0,0 +1,35 @@
import os
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
from config import IMAGE_SIZE

class CoreDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.image_paths = [
            os.path.join(image_dir, f)
            for f in sorted(os.listdir(image_dir))  # sorted for a deterministic order
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        ]
        self.transform = transform or self.default_transform()

    def default_transform(self):
        return transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, img_path
data_collector.py ADDED
@@ -0,0 +1,57 @@
import os
import requests
from PIL import Image
from io import BytesIO
import time
from config import IMAGE_DIR, DATASET_SOURCES

class DataCollector:
    def __init__(self):
        self.image_dir = IMAGE_DIR
        self.sources = DATASET_SOURCES

    def collect_sample_images(self):
        """Collect sample images from public sources"""
        # These are example URLs - in practice you'd scrape or use APIs
        sample_urls = [
            "https://c7.alamy.com/comp/3AJ86J0/gold-on-quartz-bradshaw-mountains-arizona-gold-on-quartz-from-the-bradshaw-mountains-arizona-is-a-classic-and-highly-sought-after-mineral-associa-3AJ86J0.jpg",
            "https://www.nuggetsbygrant.com/cdn/shop/products/243A0948.jpg?v=1670014792&width=1080",
            "https://news.rice.edu/sites/g/files/bxs2656/files/inline-images/BIF5-0524_540_1.jpeg",
            "https://c7.alamy.com/comp/2FNKTF3/copper-bearing-rock-against-a-gravel-ground-surface-2FNKTF3.jpg",
            "https://www.shutterstock.com/shutterstock/photos/2618131965/display_1500/stock-photo-close-up-of-a-rough-weathered-copper-ore-stone-with-natural-crystal-formations-2618131965.jpg",
            "https://geologyistheway.com/wp-content/uploads/2021/06/118-milky-quartz.jpg",
            "https://geologyistheway.com/wp-content/uploads/2021/06/201210-4-1024x726.jpg"
        ]

        print("Collecting sample drill core images...")
        for i, url in enumerate(sample_urls):
            try:
                response = requests.get(url, timeout=10)
                response.raise_for_status()

                # Convert to RGB so the image can always be saved as JPEG
                img = Image.open(BytesIO(response.content)).convert("RGB")
                img_path = os.path.join(self.image_dir, f"sample_core_{i+1}.jpg")
                img.save(img_path)
                print(f"Downloaded: sample_core_{i+1}.jpg")
                time.sleep(0.5)  # Be respectful to servers
            except Exception as e:
                print(f"Failed to download {url}: {e}")

        print(f"Collected {len(os.listdir(self.image_dir))} images")

    def get_dataset_info(self):
        """Return information about available datasets"""
        return self.sources

if __name__ == "__main__":
    collector = DataCollector()
    collector.collect_sample_images()
    print("\nAvailable geological datasets:")
    for source in collector.get_dataset_info():
        print(f"- {source['name']}: {source['description']}")
        print(f"  URL: {source['url']}\n")
feature_extractor.py ADDED
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights
from torch.utils.data import DataLoader
import numpy as np
from core_dataset import CoreDataset
from config import BATCH_SIZE

class FeatureExtractor:
    def __init__(self, device=None):
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self._load_model()

    def _load_model(self):
        """Load pretrained ResNet18 and remove classification layer"""
        weights = ResNet18_Weights.DEFAULT
        model = resnet18(weights=weights)
        # Remove the final classification layer
        model = nn.Sequential(*list(model.children())[:-1])
        model = model.to(self.device)
        model.eval()
        return model

    def extract_features(self, image_dir):
        """Extract features from all images in directory"""
        dataset = CoreDataset(image_dir)
        dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

        features = []
        image_paths = []

        print("Extracting features from images...")
        with torch.no_grad():
            for batch, paths in dataloader:
                batch = batch.to(self.device)
                batch_features = self.model(batch)
                batch_features = batch_features.view(batch_features.size(0), -1)
                features.append(batch_features.cpu().numpy())
                image_paths.extend(paths)

        features = np.vstack(features)
        print(f"Extracted features shape: {features.shape}")
        return features, image_paths

if __name__ == "__main__":
    from config import IMAGE_DIR
    extractor = FeatureExtractor()
    features, paths = extractor.extract_features(IMAGE_DIR)
    print(f"Extracted features for {len(paths)} images")
gen_ai_labeler.py ADDED
@@ -0,0 +1,94 @@
from transformers import pipeline
import torch
import os
from config import CANDIDATE_LABELS

class GenAILabeler:
    def __init__(self):
        # Zero-shot *text* classification over a generated prompt;
        # bart-large-mnli is an NLI model, not a vision-language model
        self.classifier = pipeline(
            "zero-shot-classification",
            model="facebook/bart-large-mnli",
            device=0 if torch.cuda.is_available() else -1
        )
        # More specific candidate labels
        self.candidate_labels = CANDIDATE_LABELS

    def analyze_image_content(self, image_path):
        """Infer visual characteristics from the image filename"""
        # In a real implementation, we'd use computer vision
        # For now, we create better prompts based on filenames
        filename = os.path.basename(image_path).lower()

        characteristics = []
        if 'gold' in filename:
            characteristics.append("visible metallic particles, yellow coloration")
        if 'iron' in filename or 'pyrite' in filename:
            characteristics.append("dark metallic appearance, magnetic properties")
        if 'lithium' in filename or 'spodumene' in filename:
            characteristics.append("light-colored minerals, pegmatite texture")
        if 'copper' in filename:
            characteristics.append("green or blue coloration, metallic luster")
        if 'quartz' in filename:
            characteristics.append("clear or white crystalline structure")
        if 'granite' in filename:
            characteristics.append("mixed mineral composition, coarse-grained")
        if 'basalt' in filename:
            characteristics.append("dark fine-grained texture")

        if not characteristics:
            characteristics = ["visible mineral grains", "distinctive color patterns", "unique textural features"]

        return ", ".join(characteristics)

    def label_cluster(self, sample_image_path):
        """Generate a label for a cluster based on a sample image"""
        # Get visual characteristics
        visual_features = self.analyze_image_content(sample_image_path)

        # Create a more specific prompt
        prompt = f"A geological drill core sample showing {visual_features}. "
        prompt += "What economically important mineral is most likely present in this rock sample?"

        # Perform zero-shot classification
        result = self.classifier(prompt, self.candidate_labels)

        # Return top prediction with all scores
        return {
            "label": result['labels'][0],
            "confidence": result['scores'][0],
            "all_scores": dict(zip(result['labels'], result['scores'])),
            "prompt_used": prompt
        }

    def label_all_clusters(self, cluster_map):
        """Label all clusters with improved context"""
        cluster_labels = {}

        print("Generating detailed labels for clusters using GenAI...")
        for cluster_id, image_paths in cluster_map.items():
            # Use the first image as a sample for the cluster
            sample_path = image_paths[0]
            label_info = self.label_cluster(sample_path)
            cluster_labels[cluster_id] = label_info

            print(f"\nCluster {cluster_id}:")
            print(f"  Primary Label: {label_info['label']}")
            print(f"  Confidence: {label_info['confidence']:.3f}")
            print(f"  Key Features: {self.analyze_image_content(sample_path)}")

            # Show top 3 alternative labels
            sorted_scores = sorted(label_info['all_scores'].items(), key=lambda x: x[1], reverse=True)
            print("  Alternative possibilities:")
            for label, score in sorted_scores[1:4]:
                print(f"    - {label}: {score:.3f}")

        return cluster_labels

if __name__ == "__main__":
    # This would be called from the main pipeline
    pass
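A minimal usage sketch for the labeler, assuming a cluster_map shaped like the output of ClusterAnalyzer.create_cluster_map; the image path here is hypothetical:

# Hypothetical usage: label one cluster from a sample image path
from gen_ai_labeler import GenAILabeler

cluster_map = {0: ["data/core_images/gold_sample.jpg"]}  # hypothetical map
labeler = GenAILabeler()
cluster_labels = labeler.label_all_clusters(cluster_map)
print(cluster_labels[0]["label"], round(cluster_labels[0]["confidence"], 3))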
requirements.txt ADDED
@@ -0,0 +1,12 @@
torch
torchvision
transformers
scikit-learn
Pillow
gradio
requests
numpy
pandas
matplotlib
seaborn
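All dependencies are unpinned, so pip install -r requirements.txt pulls the latest compatible release of each package.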
simple_classifier.py ADDED
@@ -0,0 +1,123 @@
import os
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torch
from torchvision.models import resnet18, ResNet18_Weights
import torch.nn as nn

class SimpleRockClassifier:
    def __init__(self):
        # Preprocessing for the pre-trained model
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        # Load ResNet model
        weights = ResNet18_Weights.DEFAULT
        self.model = resnet18(weights=weights)
        self.model = nn.Sequential(*list(self.model.children())[:-1])  # Remove final layer
        self.model.eval()

        # Simple rule-based classification based on filename
        self.keyword_mapping = {
            'gold': 'gold-bearing rock',
            'iron': 'iron-rich rock',
            'pyrite': 'iron-rich rock',
            'lithium': 'lithium-rich rock',
            'spodumene': 'lithium-rich rock',
            'copper': 'copper-bearing rock',
            'quartz': 'quartz-rich rock',
            'silica': 'quartz-rich rock',
            'crystal': 'quartz-rich rock',
            'waste': 'waste rock',
            'granite': 'waste rock',
            'basalt': 'waste rock'
        }

    def extract_features(self, image_path):
        """Extract ResNet features from an image"""
        try:
            image = Image.open(image_path).convert("RGB")
            image_tensor = self.transform(image).unsqueeze(0)

            with torch.no_grad():
                features = self.model(image_tensor)
                features = features.view(features.size(0), -1)

            return features.numpy()
        except Exception as e:
            print(f"Error extracting features: {e}")
            return np.random.rand(1, 512)  # Fallback: random 512-dim vector

    def classify_by_filename(self, image_path):
        """Classify based on filename keywords"""
        filename = os.path.basename(image_path).lower()

        for keyword, rock_type in self.keyword_mapping.items():
            if keyword in filename:
                return rock_type, 0.8

        # Default classification based on color analysis
        return self.analyze_colors(image_path)

    def analyze_colors(self, image_path):
        """Simple color analysis"""
        try:
            image = Image.open(image_path).convert("RGB")
            # Resize for faster processing
            image_small = image.resize((50, 50))
            pixels = np.array(image_small)

            # Calculate average color
            mean_color = np.mean(pixels, axis=(0, 1))

            # Simple color-based classification
            r, g, b = mean_color

            # Gold detection (yellow)
            if r > 180 and g > 150 and b < 100 and r > g > b:
                return "gold-bearing rock", 0.7

            # Iron detection (dark)
            if (r + g + b) / 3 < 100:
                return "iron-rich rock", 0.65

            # Copper detection (green/blue)
            if g > r and g > b and (r + g + b) / 3 > 80:
                return "copper-bearing rock", 0.6

            # Light minerals (lithium/quartz)
            if (r + g + b) / 3 > 200:
                # Check for purple tint (lithium)
                if abs(r - b) < 30 and (r + g + b) / 3 > 220:
                    return "lithium-rich rock", 0.55
                else:
                    return "quartz-rich rock", 0.7

            return "waste rock", 0.5

        except Exception as e:
            print(f"Error in color analysis: {e}")
            return "waste rock", 0.3

    def predict(self, image_path):
        """Main prediction function"""
        # First try filename-based classification
        rock_type, confidence = self.classify_by_filename(image_path)

        # Extract features for potential future use
        features = self.extract_features(image_path)

        return {
            "rock_type": rock_type,
            "confidence": confidence,
            "features": features,
            "explanation": f"Classified as {rock_type} based on visual characteristics"
        }
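To exercise the classifier outside the Gradio app, a quick sketch; the image path is an assumption (any local JPEG or PNG will do):

# Hypothetical quick check of SimpleRockClassifier.predict
from simple_classifier import SimpleRockClassifier

clf = SimpleRockClassifier()
result = clf.predict("data/core_images/sample_core_1.jpg")  # path is an assumption
print(result["rock_type"], f"{result['confidence']:.2f}")
print(result["explanation"])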