Spaces:
Build error
Build error
File size: 3,211 Bytes
206183d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import logging
import os

import numpy as np
import torch
from PIL import Image
from transformers import pipeline
class ImageAnalyzer:
    """Analyze an image with a fixed set of Hugging Face ``transformers`` pipelines.

    The analyzer captions the image, classifies its painting style, queries a
    color-detection classifier, and then feeds the style and caption into a
    text-generation model to produce free-form commentary.
    """

    def __init__(self, device=None):
        """Load every pipeline onto ``device``.

        Args:
            device: Device passed through to ``transformers.pipeline`` (e.g.
                ``"cpu"``, ``"cuda"``, ``"cuda:1"`` or an int GPU index).
                ``None`` picks ``"cuda"`` when available, otherwise ``"cpu"``.

        Raises:
            Exception: Whatever ``transformers.pipeline`` raises while loading
                a model (logged, then re-raised by ``_load_models``).
        """
        # Resolve the default per instance instead of once at class-definition
        # (import) time, so CUDA availability is checked when it matters.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.logger = logging.getLogger(__name__)
        self.models = self._load_models()

    def _uses_cuda(self):
        """Return True when ``self.device`` refers to a CUDA device."""
        # An int device is a GPU index (-1 conventionally meaning CPU); a plain
        # substring test would raise TypeError on it, and str() also covers
        # torch.device objects such as torch.device("cuda:0").
        if isinstance(self.device, int):
            return self.device >= 0
        return "cuda" in str(self.device)

    def _load_models(self):
        """Instantiate the pipelines and return them keyed by task name.

        Returns:
            dict mapping ``'captioning'``, ``'art_analysis'``,
            ``'color_detector'`` and ``'style_classifier'`` to pipelines.

        Raises:
            Exception: Any error from ``transformers.pipeline``; it is logged
                with its traceback and re-raised.
        """
        try:
            return {
                'captioning': pipeline(
                    "image-to-text",
                    model="Salesforce/blip2-opt-2.7b",
                    device=self.device,
                    # Half precision on GPU only; full precision everywhere else.
                    torch_dtype=torch.float16 if self._uses_cuda() else torch.float32,
                ),
                # NOTE(review): confirm this model id resolves on the Hugging Face Hub.
                'art_analysis': pipeline(
                    "text-generation",
                    model="ArtGAN/art-critique-generator",
                    device=self.device,
                ),
                # NOTE(review): confirm this model id resolves on the Hugging Face Hub.
                'color_detector': pipeline(
                    "image-classification",
                    model="google/color-detector",
                    device=self.device,
                ),
                'style_classifier': pipeline(
                    "image-classification",
                    model="dima806/art_painting_style_detection",
                    device=self.device,
                ),
            }
        except Exception as e:
            self.logger.exception("Error loading models: %s", e)
            raise

    @staticmethod
    def _load_image(image):
        """Return an in-memory PIL image for ``image``.

        Paths (``str``, ``bytes`` or ``os.PathLike``) are opened, decoded to
        RGB, and the underlying file is closed. Any other object is returned
        unchanged; it is expected to already be a PIL image.
        """
        if isinstance(image, (str, bytes, os.PathLike)):
            # NOTE(review): Pillow treats a ``bytes`` argument as a filesystem
            # path, not as encoded image data -- confirm that matches callers.
            # convert() forces a full decode, so the file can be closed here
            # instead of leaking the handle held by the lazy Image.open().
            with Image.open(image) as opened:
                return opened.convert("RGB")
        return image

    def analyze_image(self, image):
        """Run every pipeline on ``image`` and collect the results.

        Args:
            image: A PIL image, or a path accepted by ``PIL.Image.open``.

        Returns:
            dict with keys ``title``, ``description``, ``colors``, ``style``,
            ``style_confidence`` and ``art_commentary``; or ``None`` if any
            step raised (the error is logged with its traceback).
        """
        try:
            image = self._load_image(image)
            results = {}

            # Captioning with deterministic decoding, split into title/description.
            caption = self.models['captioning'](
                image,
                max_new_tokens=100,
                generate_kwargs={"do_sample": False},
            )
            results.update(self._parse_caption(caption))

            # Color detection
            results['colors'] = self._get_colors(image)

            # Style classification: keep only the first entry of the output.
            style = self.models['style_classifier'](image)[0]
            results['style'] = style['label']
            results['style_confidence'] = style['score']

            # Art commentary. return_full_text=False keeps the prompt out of the
            # returned text (the text-generation pipeline echoes it by default).
            art_prompt = f"Analyze this {results['style']} artwork: {results['description']}"
            results['art_commentary'] = self.models['art_analysis'](
                art_prompt,
                max_new_tokens=200,
                return_full_text=False,
            )[0]['generated_text']

            return results
        except Exception as e:
            # Public-API boundary: callers receive None on any failure.
            self.logger.exception("Analysis failed: %s", e)
            return None

    def _parse_caption(self, caption_output):
        """Split a caption into a title (text before the first period) and a description.

        Args:
            caption_output: Pipeline output; the text is read from
                ``caption_output[0]['generated_text']``.

        Returns:
            dict with ``title`` and ``description``. When nothing follows the
            first period (or there is no period), ``description`` falls back to
            the whole caption so downstream prompts never get an empty one.
        """
        full_text = caption_output[0]['generated_text'].strip()
        title, _, rest = full_text.partition('.')
        return {
            'title': title.strip(),
            'description': rest.strip() or full_text,
        }

    def _get_colors(self, image, top_k=5, size=(256, 256)):
        """Return the color detector's predictions for ``image``.

        Args:
            image: PIL image; it is resized to ``size`` before classification.
            top_k: Number of predictions requested from the pipeline.
            size: ``(width, height)`` passed to ``image.resize``.

        Returns:
            list of ``{'hex': label, 'score': score rounded to 3 places}``;
            presumably the labels are hex codes -- depends on the model's
            label set, TODO confirm.
        """
        colors = self.models['color_detector'](image.resize(size), top_k=top_k)
        return [
            {'hex': c['label'], 'score': round(float(c['score']), 3)}
            for c in colors
        ]