File size: 3,211 Bytes
206183d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from transformers import pipeline
import torch
from PIL import Image
import numpy as np
import logging

class ImageAnalyzer:
    """Bundle four Hugging Face pipelines and run them over a single image.

    The pipelines are built once in ``__init__`` and stored in ``self.models``
    under the keys ``'captioning'``, ``'art_analysis'``, ``'color_detector'``
    and ``'style_classifier'``. ``analyze_image`` calls them in sequence and
    merges their outputs into one dict.
    """

    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        """Store the target device, create a logger and load every pipeline.

        Args:
            device: Value forwarded as ``device=`` to each pipeline. The
                default is chosen once, when the class is defined, from
                ``torch.cuda.is_available()``.

        Raises:
            Exception: whatever ``_load_models`` re-raises when a pipeline
                cannot be constructed.
        """
        self.device = device
        self.logger = logging.getLogger(__name__)
        self.models = self._load_models()

    def _load_models(self):
        """Construct the four pipelines and return them keyed by role.

        Returns:
            dict: pipeline objects under ``'captioning'``, ``'art_analysis'``,
            ``'color_detector'`` and ``'style_classifier'``.

        Raises:
            Exception: any construction error is logged and re-raised.
        """
        try:
            # str() keeps the membership test valid when ``device`` is a
            # torch.device or an int instead of a plain string.
            on_cuda = 'cuda' in str(self.device)
            # NOTE(review): confirm that every model id below resolves on the
            # Hub; nothing in this file verifies that they exist.
            return {
                'captioning': pipeline(
                    "image-to-text",
                    model="Salesforce/blip2-opt-2.7b",
                    device=self.device,
                    torch_dtype=torch.float16 if on_cuda else torch.float32
                ),
                'art_analysis': pipeline(
                    "text-generation",
                    model="ArtGAN/art-critique-generator",
                    device=self.device
                ),
                'color_detector': pipeline(
                    "image-classification",
                    model="google/color-detector",
                    device=self.device
                ),
                'style_classifier': pipeline(
                    "image-classification",
                    model="dima806/art_painting_style_detection",
                    device=self.device
                )
            }
        except Exception as e:
            self.logger.error(f"Error loading models: {str(e)}")
            raise

    def analyze_image(self, image):
        """Caption, colour-profile, style-classify and comment on ``image``.

        Args:
            image: Either an object the pipelines accept directly (it must
                offer ``resize``, which ``_get_colors`` calls), or a path
                given as ``str``, ``bytes`` or ``os.PathLike``. Paths are
                opened with ``Image.open`` and closed again before returning.

        Returns:
            dict | None: on success a dict with the keys ``'title'``,
            ``'description'``, ``'colors'``, ``'style'``,
            ``'style_confidence'`` and ``'art_commentary'``. On any
            exception the error is logged (with traceback) and ``None`` is
            returned.
        """
        import os  # local import: only needed for the PathLike check below

        opened_here = None
        try:
            if isinstance(image, (str, bytes, os.PathLike)):
                # Remember the handle so it can be released in ``finally``.
                opened_here = Image.open(image)
                image = opened_here

            results = {}

            # Captioning
            caption = self.models['captioning'](
                image,
                max_new_tokens=100,
                generate_kwargs={"do_sample": False}
            )
            results.update(self._parse_caption(caption))

            # Color detection
            results['colors'] = self._get_colors(image)

            # Style classification: only the first returned entry is used.
            style = self.models['style_classifier'](image)[0]
            results['style'] = style['label']
            results['style_confidence'] = style['score']

            # Art analysis
            # NOTE(review): text-generation pipelines return the prompt
            # followed by the continuation unless return_full_text=False is
            # passed; confirm whether the echoed prompt is wanted here.
            art_prompt = f"Analyze this {results['style']} artwork: {results['description']}"
            results['art_commentary'] = self.models['art_analysis'](
                art_prompt,
                max_new_tokens=200
            )[0]['generated_text']

            return results

        except Exception as e:
            # Best-effort API: callers get None, the log keeps the traceback.
            self.logger.error(f"Analysis failed: {str(e)}", exc_info=True)
            return None
        finally:
            # Only close what this method opened; caller-owned images stay open.
            if opened_here is not None:
                opened_here.close()

    def _parse_caption(self, caption_output):
        """Split the first generated caption into a title and a description.

        The text is split at its first ``'.'``: the part before becomes the
        title, the part after becomes the description. When nothing usable
        follows the period (no period at all, or only whitespace after it)
        the whole stripped caption is used as the description so that it is
        never empty for a non-empty caption.

        Args:
            caption_output: Sequence whose first item is a mapping with a
                ``'generated_text'`` string.

        Returns:
            dict: ``{'title': str, 'description': str}``, both stripped.
        """
        full_text = caption_output[0]['generated_text']
        head, _, tail = full_text.partition('.')
        remainder = tail.strip()
        return {
            'title': head.strip(),
            'description': remainder if remainder else full_text.strip()
        }

    def _get_colors(self, image):
        """Return the top five colour predictions for a 256x256 copy of ``image``.

        Args:
            image: Object with a ``resize((w, h))`` method whose result the
                ``'color_detector'`` pipeline accepts.

        Returns:
            list[dict]: one ``{'hex': label, 'score': float}`` entry per
            prediction, with the score rounded to three decimals.
        """
        colors = self.models['color_detector'](
            image.resize((256, 256)),
            top_k=5
        )
        # NOTE(review): the label is stored under 'hex' verbatim; confirm the
        # classifier's labels really are hex colour codes.
        return [{
            'hex': c['label'],
            'score': round(float(c['score']), 3)
        } for c in colors]