Geologist_AI / gen_ai_labeler.py
solfedge's picture
Upload 9 files
71c32d5 verified
from transformers import pipeline
import torch
import os
from PIL import Image
import torchvision.transforms as transforms
from config import CANDIDATE_LABELS, IMAGE_SIZE
class GenAILabeler:
    """Assign mineral labels to image clusters via zero-shot text classification.

    No pixel data is inspected: a textual description is built from keywords in
    a sample image's filename and that text is classified against
    ``CANDIDATE_LABELS``.
    """

    # Ordered (keywords, description) pairs. A description is emitted when any of
    # its keywords occurs as a substring of the lower-cased filename; table order
    # fixes the order of descriptions in the output.
    # NOTE(review): plain substring matching can misfire (e.g. "iron" occurs
    # inside "environment", "gold" inside "marigold") — confirm filenames are
    # controlled enough for this to be acceptable.
    _FEATURE_HINTS = (
        (("gold",), "visible metallic particles, yellow coloration"),
        (("iron", "pyrite"), "dark metallic appearance, magnetic properties"),
        (("lithium", "spodumene"), "light-colored minerals, pegmatite texture"),
        (("copper",), "green or blue coloration, metallic luster"),
        (("quartz",), "clear or white crystalline structure"),
        (("granite",), "mixed mineral composition, coarse-grained"),
        (("basalt",), "dark fine-grained texture"),
    )

    # Generic description used when no keyword matches the filename.
    _DEFAULT_FEATURES = (
        "visible mineral grains",
        "distinctive color patterns",
        "unique textural features",
    )

    def __init__(self):
        """Load the zero-shot classifier and the candidate label set."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # facebook/bart-large-mnli is a text-only NLI model: it classifies the
        # prompt built in label_cluster(), never the image itself.
        self.classifier = pipeline(
            "zero-shot-classification",
            model="facebook/bart-large-mnli",
            # pipeline() takes a device index: 0 = first CUDA device, -1 = CPU.
            # Derived from self.device so both stay in agreement.
            device=0 if self.device.type == "cuda" else -1,
        )
        # Labels the classifier chooses between (project configuration).
        self.candidate_labels = CANDIDATE_LABELS

    def analyze_image_content(self, image_path):
        """Describe likely visual characteristics based on the image filename.

        Only ``os.path.basename(image_path)`` is inspected (case-insensitively);
        the image file is never opened.

        Args:
            image_path: Path to the image; only its final component is used.

        Returns:
            A comma-separated string of feature descriptions, in table order,
            or a generic description when no keyword matches.
        """
        filename = os.path.basename(image_path).lower()
        characteristics = [
            description
            for keywords, description in self._FEATURE_HINTS
            if any(keyword in filename for keyword in keywords)
        ]
        return ", ".join(characteristics or self._DEFAULT_FEATURES)

    def label_cluster(self, sample_image_path):
        """Generate a label for a cluster from one representative image.

        Args:
            sample_image_path: Path of the image representing the cluster; only
                its filename is used (see analyze_image_content).

        Returns:
            dict with keys:
                "label": top prediction from the classifier,
                "confidence": score of that prediction,
                "all_scores": mapping of each returned label to its score,
                "prompt_used": the text prompt sent to the classifier,
                "visual_features": the feature description embedded in the prompt.
        """
        visual_features = self.analyze_image_content(sample_image_path)
        prompt = f"A geological drill core sample showing {visual_features}. "
        prompt += "What economically important mineral is most likely present in this rock sample?"

        result = self.classifier(prompt, self.candidate_labels)

        return {
            "label": result['labels'][0],
            "confidence": result['scores'][0],
            "all_scores": dict(zip(result['labels'], result['scores'])),
            "prompt_used": prompt,
            # Exposed so callers can report the features without recomputing them.
            "visual_features": visual_features,
        }

    def label_all_clusters(self, cluster_map):
        """Label every cluster and print a short summary for each.

        The first image path of each cluster is used as its sample. Clusters
        with no image paths are skipped (with a notice) instead of raising.

        Args:
            cluster_map: Mapping of cluster id -> sequence of image paths.

        Returns:
            dict mapping each non-empty cluster id to the result of
            label_cluster() for its sample image.
        """
        cluster_labels = {}
        print("Generating detailed labels for clusters using GenAI...")
        for cluster_id, image_paths in cluster_map.items():
            # len() rather than truthiness so array-like path collections also
            # work (`not array` is ambiguous for multi-element numpy arrays).
            if len(image_paths) == 0:
                print(f"\nCluster {cluster_id}: no images, skipping.")
                continue

            sample_path = image_paths[0]
            label_info = self.label_cluster(sample_path)
            cluster_labels[cluster_id] = label_info

            print(f"\nCluster {cluster_id}:")
            print(f" Primary Label: {label_info['label']}")
            print(f" Confidence: {label_info['confidence']:.3f}")
            print(f" Key Features: {label_info['visual_features']}")

            # Show up to three runner-up labels, highest score first.
            sorted_scores = sorted(label_info['all_scores'].items(), key=lambda x: x[1], reverse=True)
            print(" Alternative possibilities:")
            for label, score in sorted_scores[1:4]:
                print(f" - {label}: {score:.3f}")
        return cluster_labels
if __name__ == "__main__":
    # Intentionally a no-op: GenAILabeler is invoked from the main pipeline
    # rather than run as a standalone script.
    ...