# Merlintxu's picture
# new file: configs/frame_templates.yaml
# 206183d
from transformers import pipeline
import torch
from PIL import Image
import numpy as np
import logging
class ImageAnalyzer:
    """Run several Hugging Face pipelines over one image and merge the results.

    For each image the analyzer produces a caption (split into a title and a
    description), a list of detected colours, a style label with its score,
    and a generated art commentary based on the style and description.
    """

    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
        """Store the target device and eagerly load every pipeline.

        Args:
            device: Device handed to each ``pipeline(...)`` call. Defaults to
                ``"cuda"`` when ``torch.cuda.is_available()`` was true at
                class-definition time, otherwise ``"cpu"``.

        Raises:
            Exception: Whatever ``_load_models`` re-raises when a pipeline
                cannot be built.
        """
        self.device = device
        self.logger = logging.getLogger(__name__)
        self.models = self._load_models()

    def _load_models(self):
        """Build the four pipelines used by ``analyze_image``.

        Returns:
            dict: Pipelines keyed by ``'captioning'``, ``'art_analysis'``,
            ``'color_detector'`` and ``'style_classifier'``.

        Raises:
            Exception: Any loading error is logged with its traceback and
                then re-raised unchanged.
        """
        # str() lets a torch.device instance be tested the same way as a
        # plain "cuda"/"cuda:0" string instead of raising TypeError on `in`.
        on_cuda = 'cuda' in str(self.device)
        try:
            # NOTE(review): confirm that every model id below resolves on the
            # Hugging Face Hub in the deployment environment.
            return {
                'captioning': pipeline(
                    "image-to-text",
                    model="Salesforce/blip2-opt-2.7b",
                    device=self.device,
                    # Half precision only on CUDA; full precision elsewhere.
                    torch_dtype=torch.float16 if on_cuda else torch.float32,
                ),
                'art_analysis': pipeline(
                    "text-generation",
                    model="ArtGAN/art-critique-generator",
                    device=self.device,
                ),
                'color_detector': pipeline(
                    "image-classification",
                    model="google/color-detector",
                    device=self.device,
                ),
                'style_classifier': pipeline(
                    "image-classification",
                    model="dima806/art_painting_style_detection",
                    device=self.device,
                ),
            }
        except Exception as exc:
            # logger.exception keeps the traceback alongside the message.
            self.logger.exception("Error loading models: %s", exc)
            raise

    def analyze_image(self, image):
        """Analyze one image with all loaded pipelines.

        Args:
            image: Either an already-loaded image object, or a ``str`` /
                ``bytes`` value that is passed to ``Image.open``.

        Returns:
            dict | None: On success a dict with the keys ``title``,
            ``description``, ``colors``, ``style``, ``style_confidence`` and
            ``art_commentary``. ``None`` when any step fails; the failure is
            logged with its traceback rather than raised.
        """
        opened_here = False
        try:
            if isinstance(image, (str, bytes)):
                image = Image.open(image)
                opened_here = True

            results = {}

            # Captioning
            caption = self.models['captioning'](
                image,
                max_new_tokens=100,
                generate_kwargs={"do_sample": False},
            )
            results.update(self._parse_caption(caption))

            # Color detection
            results['colors'] = self._get_colors(image)

            # Style classification: only the first prediction is kept.
            style = self.models['style_classifier'](image)[0]
            results['style'] = style['label']
            results['style_confidence'] = style['score']

            # Art analysis
            # NOTE(review): text-generation pipelines presumably echo the
            # prompt inside 'generated_text' unless return_full_text=False is
            # passed - confirm whether callers expect the prompt prefix.
            art_prompt = f"Analyze this {results['style']} artwork: {results['description']}"
            results['art_commentary'] = self.models['art_analysis'](
                art_prompt,
                max_new_tokens=200,
            )[0]['generated_text']

            return results
        except Exception as exc:
            # Deliberate best-effort contract: log and signal failure with None.
            self.logger.exception("Analysis failed: %s", exc)
            return None
        finally:
            # Only release images this method opened; a caller-supplied image
            # object stays under the caller's control.
            if opened_here:
                image.close()

    def _parse_caption(self, caption_output):
        """Split the captioning output into a title and a description.

        The text before the first ``'.'`` becomes the title and the rest the
        description. When nothing follows the first period (or there is no
        period at all), the whole stripped caption is used as the description
        so downstream prompts never receive an empty description for a
        non-empty caption.

        Args:
            caption_output: Sequence whose first item is a mapping with a
                ``'generated_text'`` string.

        Returns:
            dict: ``{'title': str, 'description': str}``, both stripped of
            surrounding whitespace.
        """
        full_text = caption_output[0]['generated_text'].strip()
        title, _, remainder = full_text.partition('.')
        return {
            'title': title.strip(),
            'description': remainder.strip() or full_text,
        }

    def _get_colors(self, image):
        """Classify the image's colours with the ``color_detector`` pipeline.

        Args:
            image: Image object exposing ``resize``; it is scaled to 256x256
                before classification.

        Returns:
            list[dict]: Up to five entries shaped ``{'hex': label, 'score':
            float}``, with the score rounded to three decimals. The ``hex``
            value is the classifier's label verbatim (presumably a hex colour
            code - verify against the model's label set).
        """
        colors = self.models['color_detector'](
            image.resize((256, 256)),
            top_k=5,
        )
        return [
            {'hex': c['label'], 'score': round(float(c['score']), 3)}
            for c in colors
        ]