from transformers import pipeline
import torch
import os

from config import CANDIDATE_LABELS


class GenAILabeler:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Zero-shot text classifier (BART fine-tuned on MNLI); labels are assigned
        # from a natural-language prompt, not directly from image pixels
        self.classifier = pipeline(
            "zero-shot-classification",
            model="facebook/bart-large-mnli",
            device=0 if torch.cuda.is_available() else -1
        )

        # Candidate mineral labels defined in config
        self.candidate_labels = CANDIDATE_LABELS

    def analyze_image_content(self, image_path):
        """Infer likely visual characteristics from keywords in the image filename"""
        # In a real implementation, we'd use computer vision;
        # for now, we build better prompts based on the filename
        filename = os.path.basename(image_path).lower()

        characteristics = []
        if 'gold' in filename:
            characteristics.append("visible metallic particles, yellow coloration")
        if 'iron' in filename or 'pyrite' in filename:
            characteristics.append("dark metallic appearance, magnetic properties")
        if 'lithium' in filename or 'spodumene' in filename:
            characteristics.append("light-colored minerals, pegmatite texture")
        if 'copper' in filename:
            characteristics.append("green or blue coloration, metallic luster")
        if 'quartz' in filename:
            characteristics.append("clear or white crystalline structure")
        if 'granite' in filename:
            characteristics.append("mixed mineral composition, coarse-grained")
        if 'basalt' in filename:
            characteristics.append("dark fine-grained texture")

        if not characteristics:
            characteristics = [
                "visible mineral grains",
                "distinctive color patterns",
                "unique textural features"
            ]

        return ", ".join(characteristics)

    def label_cluster(self, sample_image_path):
        """Generate a label for a cluster based on a sample image"""
        # Get visual characteristics
        visual_features = self.analyze_image_content(sample_image_path)

        # Create a more specific prompt
        prompt = f"A geological drill core sample showing {visual_features}. "
        prompt += "What economically important mineral is most likely present in this rock sample?"

        # Perform zero-shot classification
        result = self.classifier(prompt, self.candidate_labels)

        # Return the top prediction along with all scores
        return {
            "label": result['labels'][0],
            "confidence": result['scores'][0],
            "all_scores": dict(zip(result['labels'], result['scores'])),
            "prompt_used": prompt
        }

    def label_all_clusters(self, cluster_map):
        """Label all clusters with improved context"""
        cluster_labels = {}

        print("Generating detailed labels for clusters using GenAI...")

        for cluster_id, image_paths in cluster_map.items():
            # Use the first image as the sample for the cluster
            sample_path = image_paths[0]
            label_info = self.label_cluster(sample_path)
            cluster_labels[cluster_id] = label_info

            print(f"\nCluster {cluster_id}:")
            print(f"  Primary Label: {label_info['label']}")
            print(f"  Confidence: {label_info['confidence']:.3f}")
            print(f"  Key Features: {self.analyze_image_content(sample_path)}")

            # Show the top 3 alternative labels
            sorted_scores = sorted(label_info['all_scores'].items(), key=lambda x: x[1], reverse=True)
            print("  Alternative possibilities:")
            for label, score in sorted_scores[1:4]:
                print(f"    - {label}: {score:.3f}")

        return cluster_labels


if __name__ == "__main__":
    # This would be called from the main pipeline
    pass
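
# Illustrative usage (a minimal sketch, not invoked by the pipeline): the
# cluster_map structure and file paths below are hypothetical examples of the
# {cluster_id: [image paths]} mapping produced by the upstream clustering step.
#
#   labeler = GenAILabeler()
#   example_cluster_map = {
#       0: ["data/cores/gold_vein_001.jpg", "data/cores/gold_vein_002.jpg"],
#       1: ["data/cores/basalt_017.jpg"],
#   }
#   labels = labeler.label_all_clusters(example_cluster_map)
#   print(labels[0]["label"], labels[0]["confidence"])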