# Web-page header residue captured with this file ("Spaces: Sleeping"); not part of the program.
import os

import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import pipeline

from config import CANDIDATE_LABELS, IMAGE_SIZE
class GenAILabeler:
    """Label image clusters with a likely mineral via zero-shot text classification.

    No pixel data is inspected. A textual description is derived from keywords
    found in a sample image's *filename*, embedded in a prompt, and scored
    against ``CANDIDATE_LABELS`` by a zero-shot classification pipeline.
    """

    # (keywords, description) pairs, checked in this order against the
    # lower-cased file name. A description is emitted when ANY of its keywords
    # occurs as a substring, so the order here fixes the output order.
    _FILENAME_HINTS = (
        (("gold",), "visible metallic particles, yellow coloration"),
        (("iron", "pyrite"), "dark metallic appearance, magnetic properties"),
        (("lithium", "spodumene"), "light-colored minerals, pegmatite texture"),
        (("copper",), "green or blue coloration, metallic luster"),
        (("quartz",), "clear or white crystalline structure"),
        (("granite",), "mixed mineral composition, coarse-grained"),
        (("basalt",), "dark fine-grained texture"),
    )

    # Generic description used when no keyword matches the file name.
    _DEFAULT_FEATURES = (
        "visible mineral grains",
        "distinctive color patterns",
        "unique textural features",
    )

    def __init__(self):
        """Create the zero-shot classifier, placing it on GPU 0 when CUDA is available."""
        use_cuda = torch.cuda.is_available()
        # Not read inside this class; kept as a public attribute for callers.
        self.device = torch.device("cuda" if use_cuda else "cpu")
        # NOTE: "facebook/bart-large-mnli" with the "zero-shot-classification"
        # task classifies *text*; the image itself is never passed to the model.
        self.classifier = pipeline(
            "zero-shot-classification",
            model="facebook/bart-large-mnli",
            device=0 if use_cuda else -1,
        )
        self.candidate_labels = CANDIDATE_LABELS

    def analyze_image_content(self, image_path):
        """Describe an image from keywords in its file name.

        Only the base name of ``image_path`` is examined (case-insensitively);
        the file is never opened.

        Args:
            image_path: Path to the image; only its final component is used.

        Returns:
            A comma-separated string of descriptions for every matching
            keyword group, or a generic description when nothing matches.
        """
        filename = os.path.basename(image_path).lower()
        characteristics = [
            description
            for keywords, description in self._FILENAME_HINTS
            if any(keyword in filename for keyword in keywords)
        ]
        if not characteristics:
            characteristics = list(self._DEFAULT_FEATURES)
        return ", ".join(characteristics)

    def label_cluster(self, sample_image_path):
        """Generate a label for a cluster from one sample image path.

        Args:
            sample_image_path: Path of the image representing the cluster.

        Returns:
            A dict with keys ``"label"`` (first label returned by the
            classifier), ``"confidence"`` (its score), ``"all_scores"``
            (label -> score mapping) and ``"prompt_used"`` (the prompt text).
        """
        visual_features = self.analyze_image_content(sample_image_path)
        prompt = (
            f"A geological drill core sample showing {visual_features}. "
            "What economically important mineral is most likely present in this rock sample?"
        )
        result = self.classifier(prompt, self.candidate_labels)
        labels, scores = result["labels"], result["scores"]
        return {
            "label": labels[0],
            "confidence": scores[0],
            "all_scores": dict(zip(labels, scores)),
            "prompt_used": prompt,
        }

    def label_all_clusters(self, cluster_map):
        """Label every cluster, using its first image as the sample.

        Clusters with no images are reported and skipped instead of raising
        ``IndexError``; they do not appear in the returned mapping.

        Args:
            cluster_map: Mapping of cluster id -> sequence of image paths.

        Returns:
            Mapping of cluster id -> the dict produced by ``label_cluster``.
        """
        cluster_labels = {}
        print("Generating detailed labels for clusters using GenAI...")
        for cluster_id, image_paths in cluster_map.items():
            if not image_paths:
                # Nothing to sample from; indexing [0] would fail here.
                print(f"\nCluster {cluster_id}: no images, skipping")
                continue

            sample_path = image_paths[0]
            label_info = self.label_cluster(sample_path)
            cluster_labels[cluster_id] = label_info

            print(f"\nCluster {cluster_id}:")
            print(f" Primary Label: {label_info['label']}")
            print(f" Confidence: {label_info['confidence']:.3f}")
            print(f" Key Features: {self.analyze_image_content(sample_path)}")

            # Rank all labels by score; entries 1..3 are the runners-up.
            sorted_scores = sorted(
                label_info["all_scores"].items(),
                key=lambda item: item[1],
                reverse=True,
            )
            print(" Alternative possibilities:")
            for label, score in sorted_scores[1:4]:
                print(f" - {label}: {score:.3f}")
        return cluster_labels
if __name__ == "__main__":
    # No standalone entry point: this module is meant to be driven by the
    # main pipeline, so running it directly does nothing.
    pass