# Geologist_AI / cluster_analyzer.py
# solfedge's picture
# Upload 9 files
# 71c32d5 verified
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
from config import NUM_CLUSTERS, OUTPUT_DIR
import os
class ClusterAnalyzer:
    """Cluster feature vectors with a StandardScaler -> PCA -> KMeans pipeline.

    Attributes:
        n_clusters: Requested number of clusters. Treated as an upper bound:
            fewer clusters are used when there are fewer samples.
        scaler: ``StandardScaler`` that is (re)fitted by ``fit_predict``.
        kmeans: Fitted ``KMeans`` model, or ``None`` before ``fit_predict``.
        pca: Fitted ``PCA`` model, or ``None`` before ``fit_predict``.
    """

    # Upper bound on the number of PCA components kept before clustering.
    _MAX_PCA_COMPONENTS = 50
    # Seed shared by PCA and KMeans so repeated runs give the same result.
    _RANDOM_STATE = 42
    # File names longer than this are truncated in plot annotations.
    _MAX_LABEL_CHARS = 15

    def __init__(self, n_clusters=NUM_CLUSTERS):
        self.n_clusters = n_clusters
        self.scaler = StandardScaler()
        self.kmeans = None
        self.pca = None

    def fit_predict(self, features):
        """Standardize, PCA-reduce and cluster ``features``.

        Args:
            features: 2-D array-like of shape (n_samples, n_features).

        Returns:
            Tuple ``(labels, features_reduced)``: the KMeans label of every
            sample and the PCA-reduced, standardized feature matrix the
            clustering was performed on.

        Raises:
            ValueError: If ``features`` is not a non-empty 2-D array.
        """
        features = np.asarray(features)
        if features.ndim != 2 or 0 in features.shape:
            raise ValueError(
                "features must be a non-empty 2-D array of shape "
                f"(n_samples, n_features), got shape {features.shape}"
            )
        n_samples, n_features = features.shape

        features_scaled = self.scaler.fit_transform(features)

        # PCA cannot yield more than n_samples - 1 informative components, nor
        # more than n_features; always keep at least one.
        n_components = max(
            1, min(n_samples - 1, n_features, self._MAX_PCA_COMPONENTS)
        )
        print(f"Using {n_components} PCA components (adapted to data size)")
        # Seeded because PCA may pick a randomized SVD solver on large inputs,
        # which would otherwise make the clustering non-reproducible.
        self.pca = PCA(
            n_components=n_components, random_state=self._RANDOM_STATE
        )
        features_reduced = self.pca.fit_transform(features_scaled)

        # KMeans needs 1 <= n_clusters <= n_samples.
        n_clusters = max(1, min(self.n_clusters, n_samples))
        print(f"Using {n_clusters} clusters")
        self.kmeans = KMeans(
            n_clusters=n_clusters, random_state=self._RANDOM_STATE, n_init=10
        )
        labels = self.kmeans.fit_predict(features_reduced)
        return labels, features_reduced

    def get_cluster_centers(self):
        """Return the fitted KMeans cluster centers, or None if not fitted."""
        if self.kmeans is None:
            return None
        return self.kmeans.cluster_centers_

    def visualize_clusters(self, features, labels, image_paths,
                           save_path=None, max_annotations=15):
        """Draw a 2-D scatter plot of the samples coloured by cluster.

        Args:
            features: 2-D array-like of shape (n_samples, n_features). It is
                projected to two dimensions with PCA when it has more than
                two samples and more than two columns; otherwise the first
                two columns are used (zero-padded if there are fewer).
            labels: Cluster label of every sample.
            image_paths: Paths used to annotate the plotted points.
            save_path: Optional destination for the figure. When it is a
                path, missing parent directories are created.
            max_annotations: Maximum number of leading points to annotate
                with their file name.
        """
        features = np.asarray(features)
        labels = np.asarray(labels)
        n_samples, n_features = features.shape

        if n_samples > 2 and n_features > 2:
            projector = PCA(n_components=2, random_state=self._RANDOM_STATE)
            features_2d = projector.fit_transform(features)
        elif n_features >= 2:
            features_2d = features[:, :2]
        else:
            # Fewer than two columns: pad with zeros so there is a y axis.
            padding = np.zeros((n_samples, 2 - n_features))
            features_2d = np.hstack([features, padding])

        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            unique_labels = np.unique(labels)
            if unique_labels.size > 1:
                scatter = ax.scatter(
                    features_2d[:, 0], features_2d[:, 1],
                    c=labels, cmap='tab10', alpha=0.7, s=100,
                )
                fig.colorbar(scatter, ax=ax)
            else:
                ax.scatter(
                    features_2d[:, 0], features_2d[:, 1],
                    c='blue', alpha=0.7, s=100,
                )

            # Pick the title once: the single-cluster title used to be
            # overwritten by the generic one immediately after being set.
            if unique_labels.size == 1:
                title = (
                    f'All samples in single cluster (Cluster {unique_labels[0]})'
                )
            else:
                title = 'Drill Core Sample Clusters (PCA Visualization)'
            ax.set_title(title, fontsize=16)
            ax.set_xlabel('Feature Dimension 1')
            ax.set_ylabel('Feature Dimension 2')

            # Annotate only the first few points to keep the plot readable.
            annotated = features_2d[:max(0, max_annotations)]
            for point, path in zip(annotated, image_paths):
                name = os.path.basename(path)
                if len(name) > self._MAX_LABEL_CHARS:
                    # Ellipsis only when something was actually cut off.
                    name = name[:self._MAX_LABEL_CHARS] + "..."
                ax.annotate(
                    name, (point[0], point[1]),
                    xytext=(5, 5), textcoords='offset points',
                    fontsize=8, alpha=0.7,
                )

            fig.tight_layout()
            if save_path:
                # save_path may also be a file-like object; only real paths
                # have a parent directory to create.
                if isinstance(save_path, (str, os.PathLike)):
                    out_dir = os.path.dirname(save_path)
                    if out_dir:
                        os.makedirs(out_dir, exist_ok=True)
                fig.savefig(save_path, dpi=300, bbox_inches='tight')
                print(f"Cluster visualization saved to {save_path}")
            plt.show()
        finally:
            # Release the figure so repeated calls do not accumulate figures.
            plt.close(fig)

    def create_cluster_map(self, image_paths, labels):
        """Group image paths by cluster label.

        Returns:
            Dict mapping each label to the list of paths carrying it, in
            input order.
        """
        cluster_map = {}
        for path, label in zip(image_paths, labels):
            cluster_map.setdefault(label, []).append(path)
        return cluster_map

    def analyze_cluster_characteristics(self, features, labels, image_paths):
        """Compute per-cluster summary statistics.

        Args:
            features: 2-D array-like of shape (n_samples, n_features).
            labels: Cluster label of every sample (array or plain sequence).
            image_paths: Image path of every sample.

        Returns:
            Dict mapping each cluster id to a dict with keys ``'count'``
            (number of images in the cluster), ``'mean_features'`` and
            ``'std_features'`` (per-column statistics of the cluster's rows)
            and ``'sample_images'`` (up to five paths from the cluster).
        """
        # Coerce to arrays so the element-wise comparison below yields a
        # boolean mask even when plain lists are passed in.
        features = np.asarray(features)
        labels = np.asarray(labels)

        cluster_stats = {}
        for cluster_id in np.unique(labels):
            mask = labels == cluster_id
            cluster_features = features[mask]
            cluster_images = [
                path for path, in_cluster in zip(image_paths, mask)
                if in_cluster
            ]
            cluster_stats[cluster_id] = {
                'count': len(cluster_images),
                'mean_features': np.mean(cluster_features, axis=0),
                'std_features': np.std(cluster_features, axis=0),
                'sample_images': cluster_images[:5],  # First 5 samples
            }
        return cluster_stats

    def analyze_clusters(self, features, image_paths):
        """Run clustering, compute statistics, plot and print a report.

        Args:
            features: 2-D array-like of shape (n_samples, n_features).
            image_paths: Image path of every sample.

        Returns:
            Tuple ``(labels, cluster_map, cluster_stats)``.
        """
        features = np.asarray(features)
        print(f"Performing clustering analysis on {len(image_paths)} samples...")
        print(f"Feature dimensions: {features.shape}")

        labels, features_reduced = self.fit_predict(features)
        cluster_map = self.create_cluster_map(image_paths, labels)
        # Statistics are reported in the original feature space.
        cluster_stats = self.analyze_cluster_characteristics(
            features, labels, image_paths
        )

        # Visualize if we have enough samples.
        if len(image_paths) > 2:
            viz_path = os.path.join(OUTPUT_DIR, "clusters.png")
            # Plot the standardized, PCA-reduced features that KMeans actually
            # clustered, so the picture reflects the clustering geometry.
            self.visualize_clusters(
                features_reduced, labels, image_paths, viz_path
            )

        print("\n" + "="*60)
        print("CLUSTER ANALYSIS RESULTS")
        print("="*60)
        for cluster_id, stats in cluster_stats.items():
            print(f"\nCluster {cluster_id}:")
            print(f" Samples: {stats['count']} images")
            print(f" Sample files:")
            for path in stats['sample_images']:
                print(f" - {os.path.basename(path)}")
        return labels, cluster_map, cluster_stats
# Script entry point: running this module directly currently does nothing.
if __name__ == "__main__":
    pass