Spaces:
Build error
Build error
import io
import logging

import numpy as np
import torch
from PIL import Image
from transformers import pipeline
class ImageAnalyzer:
    """Caption, colour-tag, style-classify and comment on a single image.

    Four Hugging Face ``pipeline`` objects are built at construction time and
    kept in ``self.models`` under the keys ``'captioning'``,
    ``'art_analysis'``, ``'color_detector'`` and ``'style_classifier'``.
    """

    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        """Remember the target device and load every pipeline.

        Args:
            device: Value forwarded as ``device=`` to each ``pipeline`` call.
                The default is computed once, when the class body is
                executed: ``"cuda"`` if ``torch.cuda.is_available()`` is
                true, otherwise ``"cpu"``.

        Raises:
            Exception: Whatever ``_load_models`` re-raises.
        """
        self.device = device
        self.logger = logging.getLogger(__name__)
        self.models = self._load_models()

    def _load_models(self):
        """Construct the four pipelines used by ``analyze_image``.

        Returns:
            A dict mapping ``'captioning'``, ``'art_analysis'``,
            ``'color_detector'`` and ``'style_classifier'`` to their
            pipeline objects.

        Raises:
            Exception: Any failure during construction is logged (with
                traceback) and re-raised unchanged.
        """
        # str() lets a torch.device (not only a plain string) be inspected;
        # half precision is requested only when the device names CUDA.
        on_cuda = 'cuda' in str(self.device)
        caption_dtype = torch.float16 if on_cuda else torch.float32

        # NOTE(review): the model ids below are used as given; confirm that
        # each one resolves on the Hub for the deployment environment.
        try:
            return {
                'captioning': pipeline(
                    "image-to-text",
                    model="Salesforce/blip2-opt-2.7b",
                    device=self.device,
                    torch_dtype=caption_dtype,
                ),
                'art_analysis': pipeline(
                    "text-generation",
                    model="ArtGAN/art-critique-generator",
                    device=self.device,
                ),
                'color_detector': pipeline(
                    "image-classification",
                    model="google/color-detector",
                    device=self.device,
                ),
                'style_classifier': pipeline(
                    "image-classification",
                    model="dima806/art_painting_style_detection",
                    device=self.device,
                ),
            }
        except Exception as e:
            self.logger.exception("Error loading models: %s", e)
            raise

    @staticmethod
    def _open_image(source):
        """Open a path or encoded image bytes with ``Image.open``.

        ``bytes`` are first tried as in-memory image data; if that is not a
        recognisable image, they are handed to ``Image.open`` unchanged, which
        is what happened to every ``bytes`` value before.
        """
        if isinstance(source, bytes):
            try:
                return Image.open(io.BytesIO(source))
            except OSError:
                # Not decodable image data; fall through to the original
                # behaviour below.
                pass
        return Image.open(source)

    def analyze_image(self, image):
        """Run all pipelines on ``image`` and collect their results.

        Args:
            image: A ``str`` path, ``bytes`` (encoded image data or a path),
                or an object passed to the pipelines as-is. It must provide
                ``resize`` because colour detection resizes it.

        Returns:
            A dict with keys ``'title'``, ``'description'``, ``'colors'``,
            ``'style'``, ``'style_confidence'`` and ``'art_commentary'``, or
            ``None`` if any step raised (the error is logged).
        """
        try:
            if isinstance(image, (str, bytes)):
                image = self._open_image(image)

            results = {}

            # Captioning (sampling disabled via generate_kwargs).
            caption = self.models['captioning'](
                image,
                max_new_tokens=100,
                generate_kwargs={"do_sample": False},
            )
            results.update(self._parse_caption(caption))

            # Color detection
            results['colors'] = self._get_colors(image)

            # Style classification: only the first returned entry is used.
            style = self.models['style_classifier'](image)[0]
            results['style'] = style['label']
            results['style_confidence'] = style['score']

            # Art analysis, prompted with the style and caption found above.
            art_prompt = (
                f"Analyze this {results['style']} artwork: "
                f"{results['description']}"
            )
            results['art_commentary'] = self.models['art_analysis'](
                art_prompt,
                max_new_tokens=200,
            )[0]['generated_text']

            return results
        except Exception as e:
            # Deliberate best-effort: callers receive None instead of an
            # exception; the traceback is kept in the log.
            self.logger.exception("Analysis failed: %s", e)
            return None

    def _parse_caption(self, caption_output):
        """Split the generated caption into a title and a description.

        The title is the text before the first ``'.'``; the description is
        whatever follows it. When nothing follows (no period at all, or a
        single sentence ending in one) the whole stripped caption is used as
        the description so it is never empty for a non-empty caption.

        Args:
            caption_output: Sequence whose first item is a mapping with a
                ``'generated_text'`` string.

        Returns:
            ``{'title': str, 'description': str}``.

        Raises:
            IndexError, KeyError: If ``caption_output`` is empty or its first
                item lacks ``'generated_text'``.
        """
        full_text = caption_output[0]['generated_text']
        title, _, remainder = full_text.partition('.')
        return {
            'title': title.strip(),
            'description': remainder.strip() or full_text.strip(),
        }

    def _get_colors(self, image):
        """Classify a 256x256 copy of ``image`` and return the top five hits.

        Returns:
            A list of ``{'hex': label, 'score': float}`` dicts with the score
            rounded to three decimals. The label is stored under ``'hex'``
            exactly as the classifier returns it.
        """
        colors = self.models['color_detector'](
            image.resize((256, 256)),
            top_k=5,
        )
        return [
            {'hex': c['label'], 'score': round(float(c['score']), 3)}
            for c in colors
        ]