import io
import logging
import os

import numpy as np
import torch
from PIL import Image
from transformers import pipeline


class ImageAnalyzer:
    """Run a fixed set of Hugging Face pipelines over an image and merge the results.

    For each image the analyzer produces a caption-derived title/description,
    a list of color predictions, an art-style label with its confidence, and a
    free-text commentary generated from the style and description.
    """

    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        """Load every pipeline onto ``device``.

        Args:
            device: Passed straight to ``transformers.pipeline``. Either a device
                string (``"cpu"``, ``"cuda"``, ``"cuda:1"``, ...) or an integer
                ordinal (``-1`` for CPU, ``>= 0`` for a CUDA device). The default
                is ``"cuda"`` when CUDA is available, otherwise ``"cpu"``; it is
                evaluated once, when the class is defined.

        Raises:
            Exception: Whatever ``transformers.pipeline`` raises when a model
                cannot be loaded. The error is logged before being re-raised.
        """
        self.device = device
        self.logger = logging.getLogger(__name__)
        self.models = self._load_models()

    def _is_cuda_device(self):
        """Return True when ``self.device`` refers to a CUDA device."""
        if isinstance(self.device, int):
            # pipeline() convention: -1 is CPU, non-negative ints are CUDA ordinals.
            return self.device >= 0
        # str() also covers torch.device objects, which do not support `in`.
        return "cuda" in str(self.device)

    def _load_models(self):
        """Build every pipeline used by ``analyze_image``, keyed by its role."""
        # NOTE(review): confirm that "ArtGAN/art-critique-generator" and
        # "google/color-detector" are valid Hub checkpoint IDs.
        try:
            return {
                'captioning': pipeline(
                    "image-to-text",
                    model="Salesforce/blip2-opt-2.7b",
                    device=self.device,
                    # Half precision only on CUDA; CPU runs stay in float32.
                    torch_dtype=torch.float16 if self._is_cuda_device() else torch.float32,
                ),
                'art_analysis': pipeline(
                    "text-generation",
                    model="ArtGAN/art-critique-generator",
                    device=self.device,
                ),
                'color_detector': pipeline(
                    "image-classification",
                    model="google/color-detector",
                    device=self.device,
                ),
                'style_classifier': pipeline(
                    "image-classification",
                    model="dima806/art_painting_style_detection",
                    device=self.device,
                ),
            }
        except Exception as e:
            # logger.exception keeps the traceback; the caller still sees the error.
            self.logger.exception("Error loading models: %s", e)
            raise

    def _load_image(self, image):
        """Normalize ``image`` to an RGB ``PIL.Image.Image``.

        Accepts a filesystem path (``str`` or ``os.PathLike``), raw encoded
        image bytes, or an existing PIL image.
        """
        if isinstance(image, (bytes, bytearray)):
            # Image.open interprets a bytes object as a *path*; wrap the payload
            # so encoded image data is decoded instead.
            source = io.BytesIO(image)
        elif isinstance(image, (str, os.PathLike)):
            source = image
        else:
            # Already a PIL image: only copy when a mode conversion is needed.
            return image if image.mode == "RGB" else image.convert("RGB")
        # convert() fully loads the pixels, so the file can be closed here, and it
        # gives every pipeline 3-channel input even for RGBA/P/L sources.
        with Image.open(source) as opened:
            return opened.convert("RGB")

    def analyze_image(self, image):
        """Caption, color-detect, style-classify and comment on ``image``.

        Args:
            image: A ``PIL.Image.Image``, a filesystem path (``str`` or
                ``os.PathLike``), or raw encoded image bytes.

        Returns:
            A dict with keys ``title``, ``description``, ``colors``, ``style``,
            ``style_confidence`` and ``art_commentary``. Returns ``None`` if any
            step fails; the failure is logged together with its traceback.
        """
        try:
            image = self._load_image(image)
            results = {}

            # Captioning; do_sample=False turns off sampling in generation.
            caption = self.models['captioning'](
                image,
                max_new_tokens=100,
                generate_kwargs={"do_sample": False},
            )
            results.update(self._parse_caption(caption))

            # Color detection
            results['colors'] = self._get_colors(image)

            # Style classification: the first returned entry is used as the style.
            style = self.models['style_classifier'](image)[0]
            results['style'] = style['label']
            results['style_confidence'] = style['score']

            # Art analysis
            art_prompt = f"Analyze this {results['style']} artwork: {results['description']}"
            commentary = self.models['art_analysis'](
                art_prompt,
                max_new_tokens=200,
                # Without this, generated_text begins with the prompt itself.
                return_full_text=False,
            )[0]['generated_text']
            results['art_commentary'] = commentary.strip()

            return results
        except Exception as e:
            self.logger.exception("Analysis failed: %s", e)
            return None

    def _parse_caption(self, caption_output):
        """Split a caption into a title (first sentence) and a description.

        Args:
            caption_output: Output of the captioning pipeline. The text is read
                from ``caption_output[0]['generated_text']``.

        Returns:
            A dict with ``title`` (the text before the first '.') and
            ``description`` (the remainder). When there is no remainder (no '.'
            at all, or only a trailing one), the whole stripped caption is used
            as the description, so it is never empty for a non-empty caption.
        """
        full_text = caption_output[0]['generated_text'].strip()
        title, _, rest = full_text.partition('.')
        return {
            'title': title.strip(),
            'description': rest.strip() or full_text,
        }

    def _get_colors(self, image, top_k=5):
        """Return the color detector's top ``top_k`` predictions for ``image``.

        The image is resized to 256x256 (aspect ratio not preserved) before it
        is classified. Each entry is ``{'hex': label, 'score': float}``, with
        the score rounded to 3 decimals.
        """
        # NOTE(review): the 'hex' key assumes the model's labels are hex color
        # codes; confirm against the checkpoint's label set.
        predictions = self.models['color_detector'](
            image.resize((256, 256)),
            top_k=top_k,
        )
        return [
            {'hex': p['label'], 'score': round(float(p['score']), 3)}
            for p in predictions
        ]