Spaces:
Sleeping
Sleeping
Upload 9 files
Browse files- app.py +77 -0
- cluster_analyzer.py +154 -0
- config.py +48 -0
- core_dataset.py +35 -0
- data_collector.py +57 -0
- feature_extractor.py +50 -0
- gen_ai_labeler.py +94 -0
- requirements.txt +12 -0
- simple_classifier.py +123 -0
app.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
import os
|
5 |
+
from simple_classifier import SimpleRockClassifier
|
6 |
+
|
7 |
+
# Initialize classifier (no training needed)
|
8 |
+
classifier = SimpleRockClassifier()
|
9 |
+
|
10 |
+
def analyze_core(image):
    """Analyze an uploaded drill core image and return a Markdown report.

    Args:
        image: The uploaded picture as a PIL image, or ``None`` when the
            user pressed the button without uploading anything.

    Returns:
        str: Markdown text containing either the prediction (rock type,
        confidence, explanation) or an ``## Error`` section.
    """
    # Imported locally so this function stays self-contained.
    import tempfile

    if image is None:
        return "## Error\nPlease upload a drill core image before running the analysis."

    try:
        # A private temporary directory gives every request its own file, so
        # concurrent uploads cannot overwrite each other, and the directory is
        # removed even when the analysis raises. The base name is kept fixed
        # (and keyword-free) because the classifier inspects file names.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = os.path.join(temp_dir, "temp_upload.jpg")
            # JPEG cannot store alpha or palette modes; normalise to RGB first.
            image.convert("RGB").save(temp_path)
            result = classifier.predict(temp_path)

        response = "\n".join([
            "",
            "## Drill Core Analysis Results",
            "",
            "### Primary Prediction",
            f"**Rock Type:** `{result['rock_type']}`",
            f"**Confidence:** `{result['confidence']:.2f}`",
            "",
            "### Analysis Details",
            f"{result['explanation']}",
            "",
        ])
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        response = f"## Error\nAn error occurred during analysis: {str(e)}"

    return response
|
40 |
+
|
41 |
+
# Create Gradio interface
|
42 |
+
# Static texts shown by the interface.
_APP_TITLE = "Geologist_AI - Core Logger"
_ABOUT_MARKDOWN = """
This AI-powered geologist identifies rock types based on:
- **Visual color analysis**
- **Deep learning feature extraction**

**Supported rock types:**
- Gold-bearing rock
- Iron-rich rock
- Lithium-rich rock
- Copper-bearing rock
- Quartz-rich rock
- Waste rock
"""

demo = gr.Blocks(title=_APP_TITLE)
with demo:
    # Header
    gr.Markdown("# " + _APP_TITLE)
    gr.Markdown("Upload a drill core image to identify the rock type")

    # Left column: upload + trigger button. Right column: Markdown report.
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="📷 Drill Core Image")
            submit_btn = gr.Button("🔍 Analyze Core Sample", variant="primary")
        with gr.Column():
            output_text = gr.Markdown(label="📊 Analysis Results")

    # Clicking the button sends the uploaded image through analyze_core.
    submit_btn.click(fn=analyze_core, inputs=image_input, outputs=output_text)

    # Footer describing the tool.
    gr.Markdown("---")
    gr.Markdown("### About this Tool")
    gr.Markdown(_ABOUT_MARKDOWN)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
cluster_analyzer.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import numpy as np
|
3 |
+
from sklearn.cluster import KMeans
|
4 |
+
from sklearn.decomposition import PCA
|
5 |
+
from sklearn.preprocessing import StandardScaler
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import seaborn as sns
|
8 |
+
from config import NUM_CLUSTERS, OUTPUT_DIR
|
9 |
+
import os
|
10 |
+
|
11 |
+
class ClusterAnalyzer:
    """Cluster feature vectors with a StandardScaler -> PCA -> KMeans pipeline.

    The fitted ``scaler``, ``pca`` and ``kmeans`` objects are kept on the
    instance after ``fit_predict`` so they can be inspected afterwards.
    """

    def __init__(self, n_clusters=NUM_CLUSTERS):
        """Create an unfitted analyzer.

        Args:
            n_clusters: Requested number of KMeans clusters. It is lowered
                automatically when there are fewer samples than clusters.
        """
        self.n_clusters = n_clusters
        self.scaler = StandardScaler()
        self.kmeans = None
        self.pca = None

    def fit_predict(self, features):
        """Standardize, reduce with PCA, cluster with KMeans.

        Args:
            features: 2-D array-like of shape (n_samples, n_features).

        Returns:
            tuple: ``(labels, features_reduced)`` - the KMeans label of each
            sample and the PCA-reduced features that were clustered.

        Raises:
            ValueError: If ``features`` is empty or not two-dimensional.
        """
        features = np.asarray(features)
        if features.ndim != 2 or features.shape[0] == 0:
            raise ValueError(
                "features must be a non-empty 2-D array of shape "
                "(n_samples, n_features)"
            )

        features_scaled = self.scaler.fit_transform(features)
        n_samples, n_features = features_scaled.shape

        # Adaptive PCA: min(n_samples - 1, n_features, 50), but never below 1.
        n_components = max(1, min(n_samples - 1, n_features, 50))
        print(f"Using {n_components} PCA components (adapted to data size)")
        self.pca = PCA(n_components=n_components)
        features_reduced = self.pca.fit_transform(features_scaled)

        # KMeans cannot build more clusters than there are samples.
        n_clusters = max(1, min(self.n_clusters, n_samples))
        print(f"Using {n_clusters} clusters")
        self.kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        labels = self.kmeans.fit_predict(features_reduced)
        return labels, features_reduced

    def get_cluster_centers(self):
        """Return the fitted KMeans cluster centers, or None before fitting."""
        if self.kmeans is None:
            return None
        return self.kmeans.cluster_centers_

    def visualize_clusters(self, features, labels, image_paths, save_path=None):
        """Draw a 2-D scatter plot of the samples colored by cluster.

        Args:
            features: 2-D array of shape (n_samples, n_features). It is
                projected to two dimensions with PCA when it has more than
                two rows and columns; otherwise the first two columns are
                used (zero-padded if fewer exist).
            labels: Cluster label of every sample.
            image_paths: Paths matching the rows of ``features``; the first
                15 points are annotated with their file names.
            save_path: Optional file path to save the figure to.
        """
        features = np.asarray(features)
        labels = np.asarray(labels)

        if features.shape[0] > 2 and features.shape[1] > 2:
            features_2d = PCA(n_components=2).fit_transform(features)
        elif features.shape[1] >= 2:
            features_2d = features[:, :2]
        else:
            padding = np.zeros((features.shape[0], 2 - features.shape[1]))
            features_2d = np.hstack([features, padding])

        fig = plt.figure(figsize=(12, 8))

        unique_labels = np.unique(labels)
        title = 'Drill Core Sample Clusters (PCA Visualization)'
        if len(unique_labels) > 1:
            scatter = plt.scatter(features_2d[:, 0], features_2d[:, 1], c=labels,
                                  cmap='tab10', alpha=0.7, s=100)
            plt.colorbar(scatter)
        else:
            plt.scatter(features_2d[:, 0], features_2d[:, 1], c='blue', alpha=0.7, s=100)
            if len(unique_labels) == 1:
                # Previously this title was set and then unconditionally
                # overwritten by the generic one.
                title = f'All samples in single cluster (Cluster {unique_labels[0]})'

        plt.title(title, fontsize=16)
        plt.xlabel('Feature Dimension 1')
        plt.ylabel('Feature Dimension 2')

        # Annotate at most the first 15 points to keep the plot readable.
        for (x, y), path in zip(features_2d[:15], image_paths):
            name = os.path.basename(path)
            # Only mark names as shortened when they really were.
            caption = name if len(name) <= 15 else name[:15] + "..."
            plt.annotate(caption, (x, y), xytext=(5, 5),
                         textcoords='offset points', fontsize=8, alpha=0.7)

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Cluster visualization saved to {save_path}")
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    def create_cluster_map(self, image_paths, labels):
        """Return a dict mapping each cluster label to its image paths."""
        cluster_map = {}
        for path, label in zip(image_paths, labels):
            cluster_map.setdefault(label, []).append(path)
        return cluster_map

    def analyze_cluster_characteristics(self, features, labels, image_paths):
        """Compute per-cluster summary statistics.

        Args:
            features: 2-D array of shape (n_samples, n_features).
            labels: Cluster label of every sample.
            image_paths: Paths matching the rows of ``features``.

        Returns:
            dict: cluster id -> ``{'count', 'mean_features', 'std_features',
            'sample_images'}`` where ``sample_images`` holds at most the
            first five paths of the cluster.
        """
        features = np.asarray(features)
        labels = np.asarray(labels)
        cluster_stats = {}

        for cluster_id in np.unique(labels):
            mask = labels == cluster_id
            cluster_features = features[mask]
            cluster_images = [path for path, keep in zip(image_paths, mask) if keep]

            cluster_stats[cluster_id] = {
                'count': len(cluster_images),
                'mean_features': np.mean(cluster_features, axis=0),
                'std_features': np.std(cluster_features, axis=0),
                'sample_images': cluster_images[:5]  # First 5 samples
            }

        return cluster_stats

    def analyze_clusters(self, features, image_paths):
        """Run the complete clustering analysis and print a summary.

        Args:
            features: 2-D array of shape (n_samples, n_features).
            image_paths: Paths matching the rows of ``features``.

        Returns:
            tuple: ``(labels, cluster_map, cluster_stats)``.
        """
        print(f"Performing clustering analysis on {len(image_paths)} samples...")
        print(f"Feature dimensions: {features.shape}")

        labels, features_reduced = self.fit_predict(features)
        cluster_map = self.create_cluster_map(image_paths, labels)

        # Statistics are reported in the original feature space.
        cluster_stats = self.analyze_cluster_characteristics(features, labels, image_paths)

        # Visualize if we have enough samples. The plot uses the scaled and
        # PCA-reduced features, i.e. the space the clusters were found in.
        if len(image_paths) > 2:
            viz_path = os.path.join(OUTPUT_DIR, "clusters.png")
            self.visualize_clusters(features_reduced, labels, image_paths, viz_path)

        print("\n" + "="*60)
        print("CLUSTER ANALYSIS RESULTS")
        print("="*60)
        for cluster_id, stats in cluster_stats.items():
            print(f"\nCluster {cluster_id}:")
            print(f" Samples: {stats['count']} images")
            print(" Sample files:")
            for path in stats['sample_images']:
                print(f" - {os.path.basename(path)}")

        return labels, cluster_map, cluster_stats
|
152 |
+
|
153 |
+
if __name__ == "__main__":
|
154 |
+
pass
|
config.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
|
4 |
+
# Directories
DATA_DIR = "data"
IMAGE_DIR = os.path.join(DATA_DIR, "core_images")
MODEL_DIR = "models"
OUTPUT_DIR = "output"

# Make sure every working directory exists as soon as this module is imported.
for _directory in (IMAGE_DIR, MODEL_DIR, OUTPUT_DIR):
    os.makedirs(_directory, exist_ok=True)
del _directory

# Model parameters - adaptive to data size
NUM_CLUSTERS = 3  # Reduced default
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 32

# Candidate labels for classification; each one is "<descriptor> rock".
CANDIDATE_LABELS = [
    f"{descriptor} rock"
    for descriptor in (
        "gold-bearing",
        "iron-rich",
        "lithium-rich",
        "copper-bearing",
        "waste",
        "quartz-rich",
        "sulfide-rich",
    )
]

# Public geology repositories, one dict (name, url, description) per source.
DATASET_SOURCES = [
    {"name": name, "url": url, "description": description}
    for name, url, description in (
        (
            "Geoscience Australia",
            "https://geology.csiro.au/datasets/drill-core-images",
            "Australian geological survey drill core images",
        ),
        (
            "USGS Mineral Resources",
            "https://mrdata.usgs.gov/geology/state/map-viewer.php",
            "US Geological Survey mineral resources data",
        ),
        (
            "BGS OpenGeoscience",
            "https://www.bgs.ac.uk/discovering-geology/rock-library/",
            "British Geological Survey rock sample images",
        ),
    )
]
|
core_dataset.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
from PIL import Image
|
6 |
+
import torchvision.transforms as transforms
|
7 |
+
from config import IMAGE_SIZE
|
8 |
+
|
9 |
+
class CoreDataset(Dataset):
    """Dataset of the image files found directly inside one directory.

    Each item is ``(image, path)`` where ``image`` is the RGB picture after
    the transform has been applied and ``path`` is the file it came from.
    """

    # File extensions (compared case-insensitively) treated as images.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, image_dir, transform=None):
        """Index the images in ``image_dir``.

        Args:
            image_dir: Directory to scan (not recursive).
            transform: Optional callable applied to every PIL image. When
                omitted, ``default_transform()`` is used.
        """
        self.image_dir = image_dir
        # os.listdir returns entries in arbitrary order; sorting makes the
        # sample order (and everything derived from it) reproducible.
        self.image_paths = [
            os.path.join(image_dir, name)
            for name in sorted(os.listdir(image_dir))
            if name.lower().endswith(self._IMAGE_EXTENSIONS)
        ]
        self.transform = transform if transform is not None else self.default_transform()

    def default_transform(self):
        """Return the default pipeline: resize to IMAGE_SIZE, to tensor, normalize."""
        return transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply the transform, return ``(image, path)``."""
        img_path = self.image_paths[idx]
        # Close the file handle as soon as the pixel data has been copied.
        with Image.open(img_path) as raw:
            image = raw.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, img_path
|
data_collector.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import requests
|
4 |
+
from PIL import Image
|
5 |
+
from io import BytesIO
|
6 |
+
import time
|
7 |
+
from config import IMAGE_DIR, DATASET_SOURCES
|
8 |
+
|
9 |
+
class DataCollector:
    """Download a handful of example rock images into the image directory."""

    # Example URLs - in practice you'd scrape or use APIs.
    # NOTE(review): these point at third-party stock-photo and news sites;
    # confirm the images may be downloaded and reused before relying on them.
    DEFAULT_SAMPLE_URLS = (
        "https://c7.alamy.com/comp/3AJ86J0/gold-on-quartz-bradshaw-mountains-arizona-gold-on-quartz-from-the-bradshaw-mountains-arizona-is-a-classic-and-highly-sought-after-mineral-associa-3AJ86J0.jpg",
        "https://www.nuggetsbygrant.com/cdn/shop/products/243A0948.jpg?v=1670014792&width=1080",
        "https://news.rice.edu/sites/g/files/bxs2656/files/inline-images/BIF5-0524_540_1.jpeg",
        "https://c7.alamy.com/comp/2FNKTF3/copper-bearing-rock-against-a-gravel-ground-surface-2FNKTF3.jpg",
        "https://www.shutterstock.com/shutterstock/photos/2618131965/display_1500/stock-photo-close-up-of-a-rough-weathered-copper-ore-stone-with-natural-crystal-formations-2618131965.jpg",
        "https://geologyistheway.com/wp-content/uploads/2021/06/118-milky-quartz.jpg",
        "https://geologyistheway.com/wp-content/uploads/2021/06/201210-4-1024x726.jpg",
    )

    def __init__(self):
        """Read the target directory and dataset list from the config module."""
        self.image_dir = IMAGE_DIR
        self.sources = DATASET_SOURCES

    def collect_sample_images(self, urls=None, delay=0.5):
        """Download images and store them as ``sample_core_<n>.jpg``.

        Args:
            urls: Iterable of image URLs. Defaults to ``DEFAULT_SAMPLE_URLS``.
            delay: Seconds to wait after each successful download.

        Returns:
            int: Number of images downloaded and saved by this call.
        """
        if urls is None:
            urls = self.DEFAULT_SAMPLE_URLS
        os.makedirs(self.image_dir, exist_ok=True)

        print("Collecting sample drill core images...")
        downloaded = 0
        for i, url in enumerate(urls, start=1):
            filename = f"sample_core_{i}.jpg"
            try:
                response = requests.get(url, timeout=10)
                response.raise_for_status()
                with Image.open(BytesIO(response.content)) as img:
                    # JPEG cannot store alpha/palette modes; without the
                    # conversion such images failed to save.
                    img.convert("RGB").save(os.path.join(self.image_dir, filename))
            except (requests.RequestException, OSError, ValueError) as e:
                # Best effort: report the failure and carry on with the rest.
                print(f"Failed to download {url}: {e}")
                continue

            downloaded += 1
            print(f"Downloaded: {filename}")
            time.sleep(delay)  # Be respectful to servers

        # Report what this run fetched, not whatever is already in the folder.
        print(f"Collected {downloaded} images")
        return downloaded

    def get_dataset_info(self):
        """Return the list of known public dataset descriptions."""
        return self.sources
|
50 |
+
|
51 |
+
# Script entry point: fetch the sample images, then list the known datasets.
if __name__ == "__main__":
    collector = DataCollector()
    collector.collect_sample_images()

    print("\nAvailable geological datasets:")
    for entry in collector.get_dataset_info():
        headline = f"- {entry['name']}: {entry['description']}"
        link = f" URL: {entry['url']}\n"
        print(headline)
        print(link)
|
feature_extractor.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torchvision.models import resnet18, ResNet18_Weights
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
import numpy as np
|
7 |
+
from core_dataset import CoreDataset
|
8 |
+
from config import BATCH_SIZE
|
9 |
+
|
10 |
+
class FeatureExtractor:
    """Turn images into feature vectors with a pretrained ResNet18 backbone."""

    def __init__(self, device=None):
        """Load the model onto ``device``.

        Args:
            device: Torch device to run on. Defaults to CUDA when available,
                otherwise the CPU.
        """
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self._load_model()

    def _load_model(self):
        """Load pretrained ResNet18 without its final classification layer."""
        backbone = resnet18(weights=ResNet18_Weights.DEFAULT)
        # Dropping the last child removes the classification layer.
        model = nn.Sequential(*list(backbone.children())[:-1])
        model = model.to(self.device)
        model.eval()
        return model

    def extract_features(self, image_dir):
        """Extract one feature vector per image in ``image_dir``.

        Args:
            image_dir: Directory handed to ``CoreDataset``.

        Returns:
            tuple: ``(features, image_paths)`` - a 2-D numpy array with one
            row per image and the matching list of file paths.

        Raises:
            ValueError: If the directory contains no usable images.
        """
        dataset = CoreDataset(image_dir)
        if len(dataset) == 0:
            # Without this check np.vstack failed with an opaque
            # "need at least one array" ValueError.
            raise ValueError(f"No supported images found in {image_dir!r}")

        dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

        features = []
        image_paths = []

        print("Extracting features from images...")
        with torch.no_grad():
            for batch, paths in dataloader:
                batch_features = self.model(batch.to(self.device))
                # Collapse everything after the batch dimension into one vector.
                batch_features = torch.flatten(batch_features, start_dim=1)
                features.append(batch_features.cpu().numpy())
                image_paths.extend(paths)

        features = np.vstack(features)
        print(f"Extracted features shape: {features.shape}")
        return features, image_paths
|
45 |
+
|
46 |
+
# Script entry point: extract features for every image in the configured folder.
if __name__ == "__main__":
    from config import IMAGE_DIR

    extractor = FeatureExtractor()
    extraction = extractor.extract_features(IMAGE_DIR)
    features, paths = extraction
    image_total = len(paths)
    print(f"Extracted features for {image_total} images")
|
gen_ai_labeler.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from transformers import pipeline
|
3 |
+
import torch
|
4 |
+
import os
|
5 |
+
from PIL import Image
|
6 |
+
import torchvision.transforms as transforms
|
7 |
+
from config import CANDIDATE_LABELS, IMAGE_SIZE
|
8 |
+
|
9 |
+
class GenAILabeler:
    """Label clusters with a zero-shot text classifier.

    Note that no pixel data is analysed here: a text prompt is built from
    keywords found in the sample image's *file name* and then classified
    against the candidate labels by a text-only NLI model.
    """

    # (keywords, description) pairs checked in order against the file name.
    _FILENAME_HINTS = (
        (("gold",), "visible metallic particles, yellow coloration"),
        (("iron", "pyrite"), "dark metallic appearance, magnetic properties"),
        (("lithium", "spodumene"), "light-colored minerals, pegmatite texture"),
        (("copper",), "green or blue coloration, metallic luster"),
        (("quartz",), "clear or white crystalline structure"),
        (("granite",), "mixed mineral composition, coarse-grained"),
        (("basalt",), "dark fine-grained texture"),
    )
    # Used when the file name contains none of the keywords above.
    _GENERIC_HINTS = (
        "visible mineral grains",
        "distinctive color patterns",
        "unique textural features",
    )

    def __init__(self):
        """Create the zero-shot classification pipeline."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Text-only zero-shot classifier (it never sees the image itself).
        self.classifier = pipeline(
            "zero-shot-classification",
            model="facebook/bart-large-mnli",
            device=0 if torch.cuda.is_available() else -1
        )
        self.candidate_labels = CANDIDATE_LABELS

    def analyze_image_content(self, image_path):
        """Derive a textual description from keywords in the file name.

        In a real implementation computer vision would be used; for now the
        description depends only on the base name of ``image_path``.

        Returns:
            str: Comma-separated characteristics, or a generic description
            when no keyword matches.
        """
        filename = os.path.basename(image_path).lower()

        characteristics = [
            description
            for keywords, description in self._FILENAME_HINTS
            if any(keyword in filename for keyword in keywords)
        ]
        if not characteristics:
            characteristics = list(self._GENERIC_HINTS)

        return ", ".join(characteristics)

    def label_cluster(self, sample_image_path):
        """Generate a label for a cluster from one sample image path.

        Returns:
            dict: ``label`` and ``confidence`` of the top prediction,
            ``all_scores`` (label -> score), ``prompt_used`` and the
            ``visual_features`` text the prompt was built from.
        """
        visual_features = self.analyze_image_content(sample_image_path)

        prompt = f"A geological drill core sample showing {visual_features}. "
        prompt += "What economically important mineral is most likely present in this rock sample?"

        result = self.classifier(prompt, self.candidate_labels)

        return {
            "label": result['labels'][0],
            "confidence": result['scores'][0],
            "all_scores": dict(zip(result['labels'], result['scores'])),
            "prompt_used": prompt,
            "visual_features": visual_features,
        }

    def label_all_clusters(self, cluster_map):
        """Label every non-empty cluster and print a short report.

        Args:
            cluster_map: Mapping of cluster id -> list of image paths.

        Returns:
            dict: cluster id -> result of ``label_cluster`` for the first
            image of that cluster. Clusters without images are skipped.
        """
        cluster_labels = {}

        print("Generating detailed labels for clusters using GenAI...")
        for cluster_id, image_paths in cluster_map.items():
            if not image_paths:
                # Nothing to sample from; indexing [0] would raise IndexError.
                print(f"\nCluster {cluster_id}: no images, skipped")
                continue

            # Use first image as sample for the cluster
            label_info = self.label_cluster(image_paths[0])
            cluster_labels[cluster_id] = label_info

            print(f"\nCluster {cluster_id}:")
            print(f" Primary Label: {label_info['label']}")
            print(f" Confidence: {label_info['confidence']:.3f}")
            print(f" Key Features: {label_info['visual_features']}")

            # Show the next three labels by descending score.
            sorted_scores = sorted(label_info['all_scores'].items(), key=lambda x: x[1], reverse=True)
            print(" Alternative possibilities:")
            for label, score in sorted_scores[1:4]:
                print(f" - {label}: {score:.3f}")

        return cluster_labels
|
91 |
+
|
92 |
+
if __name__ == "__main__":
|
93 |
+
# This would be called from the main pipeline
|
94 |
+
pass
|
requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
torch
|
3 |
+
torchvision
|
4 |
+
transformers
|
5 |
+
scikit-learn
|
6 |
+
Pillow
|
7 |
+
gradio
|
8 |
+
requests
|
9 |
+
numpy
|
10 |
+
pandas
|
11 |
+
matplotlib
|
12 |
+
seaborn
|
simple_classifier.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import torchvision.transforms as transforms
|
7 |
+
from config import CANDIDATE_LABELS
|
8 |
+
import torch
|
9 |
+
from torchvision.models import resnet18, ResNet18_Weights
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
class SimpleRockClassifier:
    """Rule-based rock classifier that needs no training.

    The rock type comes from a keyword in the image's file name when one is
    present, otherwise from thresholds on the image's average color. ResNet18
    features are extracted and returned alongside the prediction but do not
    influence it.
    """

    # Length of the feature vector used for the fallback in extract_features.
    _FEATURE_DIM = 512

    def __init__(self):
        """Build the preprocessing pipeline, the backbone and the keyword table."""
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        # Pretrained ResNet18 with the final layer removed.
        weights = ResNet18_Weights.DEFAULT
        self.model = resnet18(weights=weights)
        self.model = nn.Sequential(*list(self.model.children())[:-1])  # Remove final layer
        self.model.eval()

        # File-name keyword -> rock type; checked in insertion order.
        self.keyword_mapping = {
            'gold': 'gold-bearing rock',
            'iron': 'iron-rich rock',
            'pyrite': 'iron-rich rock',
            'lithium': 'lithium-rich rock',
            'spodumene': 'lithium-rich rock',
            'copper': 'copper-bearing rock',
            'quartz': 'quartz-rich rock',
            'silica': 'quartz-rich rock',
            'crystal': 'quartz-rich rock',
            'waste': 'waste rock',
            'granite': 'waste rock',
            'basalt': 'waste rock'
        }

    def extract_features(self, image_path):
        """Return the backbone features of one image as a ``(1, N)`` array.

        On any failure the error is printed and an all-zero ``(1, 512)``
        array is returned, so callers always receive a deterministic value
        (previously the fallback was random noise).
        """
        try:
            with Image.open(image_path) as raw:
                image = raw.convert("RGB")
            image_tensor = self.transform(image).unsqueeze(0)

            with torch.no_grad():
                features = torch.flatten(self.model(image_tensor), start_dim=1)

            return features.numpy()
        except Exception as e:
            # Deliberate best effort: feature extraction must not break predict().
            print(f"Error extracting features: {e}")
            return np.zeros((1, self._FEATURE_DIM), dtype=np.float32)

    def _match_keyword(self, image_path):
        """Return the rock type of the first keyword found in the file name, else None."""
        filename = os.path.basename(image_path).lower()
        for keyword, rock_type in self.keyword_mapping.items():
            if keyword in filename:
                return rock_type
        return None

    def classify_by_filename(self, image_path):
        """Classify by file-name keyword, falling back to color analysis.

        Returns:
            tuple: ``(rock_type, confidence)``; keyword matches use a fixed
            confidence of 0.8.
        """
        rock_type = self._match_keyword(image_path)
        if rock_type is not None:
            return rock_type, 0.8

        # Default classification based on color analysis
        return self.analyze_colors(image_path)

    @staticmethod
    def _classify_mean_color(r, g, b):
        """Map an average RGB color (0-255 per channel) to ``(rock_type, confidence)``.

        The rules are checked in order and the first match wins.
        """
        brightness = (r + g + b) / 3

        # Gold detection (yellow)
        if r > 180 and g > 150 and b < 100 and r > g > b:
            return "gold-bearing rock", 0.7

        # Iron detection (dark)
        if brightness < 100:
            return "iron-rich rock", 0.65

        # Copper detection (green is the dominant channel)
        if g > r and g > b and brightness > 80:
            return "copper-bearing rock", 0.6

        # Light minerals (lithium/quartz)
        if brightness > 200:
            # Very bright with red and blue close together -> lithium
            if abs(r - b) < 30 and brightness > 220:
                return "lithium-rich rock", 0.55
            return "quartz-rich rock", 0.7

        return "waste rock", 0.5

    def analyze_colors(self, image_path):
        """Classify an image from its average color.

        Returns:
            tuple: ``(rock_type, confidence)``; ``("waste rock", 0.3)`` when
            the image cannot be read or processed.
        """
        try:
            with Image.open(image_path) as raw:
                # Downscale first: only the mean color is needed.
                pixels = np.array(raw.convert("RGB").resize((50, 50)))

            r, g, b = np.mean(pixels, axis=(0, 1))
            return self._classify_mean_color(r, g, b)
        except Exception as e:
            # Deliberate best effort: fall back to a low-confidence default.
            print(f"Error in color analysis: {e}")
            return "waste rock", 0.3

    def predict(self, image_path):
        """Classify one image.

        Returns:
            dict: ``rock_type``, ``confidence``, ``features`` (see
            ``extract_features``) and a human-readable ``explanation``.
        """
        rock_type, confidence = self.classify_by_filename(image_path)

        # Say honestly which rule produced the result.
        if self._match_keyword(image_path) is not None:
            basis = "a keyword in the file name"
        else:
            basis = "visual characteristics"

        # Extract features for potential future use
        features = self.extract_features(image_path)

        return {
            "rock_type": rock_type,
            "confidence": confidence,
            "features": features,
            "explanation": f"Classified as {rock_type} based on {basis}"
        }
|