File size: 3,966 Bytes
71c32d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95

from transformers import pipeline
import torch
import os
from PIL import Image
import torchvision.transforms as transforms
from config import CANDIDATE_LABELS, IMAGE_SIZE

class GenAILabeler:
    """Label image clusters with a likely mineral using zero-shot classification.

    No pixel data is read by this class. A textual description is derived
    from keywords found in a sample image's *filename*, embedded in a
    natural-language prompt, and scored against ``CANDIDATE_LABELS`` by a
    text-only zero-shot classifier.
    """

    # Ordered (filename keywords, description) pairs. Order matters: it fixes
    # the order in which descriptions appear in the generated prompt.
    _FILENAME_HINTS = (
        (("gold",), "visible metallic particles, yellow coloration"),
        (("iron", "pyrite"), "dark metallic appearance, magnetic properties"),
        (("lithium", "spodumene"), "light-colored minerals, pegmatite texture"),
        (("copper",), "green or blue coloration, metallic luster"),
        (("quartz",), "clear or white crystalline structure"),
        (("granite",), "mixed mineral composition, coarse-grained"),
        (("basalt",), "dark fine-grained texture"),
    )

    # Generic description used when the filename matches no known keyword.
    _DEFAULT_CHARACTERISTICS = (
        "visible mineral grains",
        "distinctive color patterns",
        "unique textural features",
    )

    def __init__(self):
        """Load the zero-shot classification pipeline and candidate labels."""
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        # NOTE: facebook/bart-large-mnli is a text-only NLI model, not a
        # vision-language model; it only ever sees the prompt text built in
        # label_cluster(), never the image itself.
        self.classifier = pipeline(
            "zero-shot-classification",
            model="facebook/bart-large-mnli",
            device=0 if use_cuda else -1,  # pipeline convention: -1 means CPU
        )
        # Labels the classifier chooses between (defined in config).
        self.candidate_labels = CANDIDATE_LABELS

    def analyze_image_content(self, image_path):
        """Describe an image based solely on keywords in its filename.

        The image file is never opened; this is a placeholder for real
        computer-vision feature extraction.

        Args:
            image_path: Path to the image. Only the lower-cased basename is
                inspected.

        Returns:
            A comma-separated description string. If no known keyword is
            present in the filename, a generic description is returned.
        """
        filename = os.path.basename(image_path).lower()

        characteristics = [
            description
            for keywords, description in self._FILENAME_HINTS
            if any(keyword in filename for keyword in keywords)
        ]
        if not characteristics:
            characteristics = list(self._DEFAULT_CHARACTERISTICS)

        return ", ".join(characteristics)

    def label_cluster(self, sample_image_path):
        """Generate a label for a cluster from one representative image path.

        Args:
            sample_image_path: Path of the image chosen to represent the
                cluster.

        Returns:
            A dict with keys:
                ``label``: first label returned by the classifier,
                ``confidence``: the score paired with that label,
                ``all_scores``: mapping of every candidate label to its score,
                ``prompt_used``: the exact prompt sent to the classifier.
        """
        visual_features = self.analyze_image_content(sample_image_path)

        prompt = f"A geological drill core sample showing {visual_features}. "
        prompt += "What economically important mineral is most likely present in this rock sample?"

        result = self.classifier(prompt, self.candidate_labels)

        # Index 0 is treated as the top prediction; this relies on the
        # pipeline returning labels ordered by descending score.
        return {
            "label": result['labels'][0],
            "confidence": result['scores'][0],
            "all_scores": dict(zip(result['labels'], result['scores'])),
            "prompt_used": prompt
        }

    def label_all_clusters(self, cluster_map):
        """Label every non-empty cluster and print a summary for each.

        Args:
            cluster_map: Mapping of cluster id to a sequence of image paths.

        Returns:
            Dict mapping cluster id to the dict produced by
            ``label_cluster``. Clusters with no images are skipped (a notice
            is printed) and therefore have no entry in the result.
        """
        cluster_labels = {}

        print("Generating detailed labels for clusters using GenAI...")
        for cluster_id, image_paths in cluster_map.items():
            # An empty cluster has no sample image; skip it rather than
            # letting image_paths[0] raise IndexError and abort the whole run.
            # len() is used so array-like containers work as well as lists.
            if len(image_paths) == 0:
                print(f"\nCluster {cluster_id}: no images in cluster, skipping")
                continue

            # Use first image as sample for the cluster
            sample_path = image_paths[0]
            label_info = self.label_cluster(sample_path)
            cluster_labels[cluster_id] = label_info

            print(f"\nCluster {cluster_id}:")
            print(f"  Primary Label: {label_info['label']}")
            print(f"  Confidence: {label_info['confidence']:.3f}")
            print(f"  Key Features: {self.analyze_image_content(sample_path)}")

            # Show up to three runner-up labels (rank 2-4 by score).
            sorted_scores = sorted(label_info['all_scores'].items(), key=lambda x: x[1], reverse=True)
            print("  Alternative possibilities:")
            for label, score in sorted_scores[1:4]:
                print(f"    - {label}: {score:.3f}")

        return cluster_labels

if __name__ == "__main__":
    # Intentionally a no-op: per the original note, this module is meant to
    # be called from the main pipeline rather than run as a standalone script.
    pass