# aria/image_encoder.py
import os
from typing import Tuple, Union

import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPProcessor, CLIPModel


class ImageEncoder(nn.Module):
def __init__(self, clip_model_name: str = "openai/clip-vit-large-patch14-336"):
"""Initialize the image encoder using CLIP.
Args:
clip_model_name: HuggingFace model name for CLIP
"""
super().__init__()
# Store model name for lazy loading
self.clip_model_name = clip_model_name
self.clip_model = None
self.processor = None
self.valence_head = None
self.arousal_head = None
self.device = None
self._initialized = False

    def _ensure_initialized(self):
"""Lazy initialization of the model components."""
if self._initialized:
return
print(f"Initializing ImageEncoder with {self.clip_model_name}...")
print("Loading CLIP model from local cache (network disabled)...")
# Prefer loading strictly from the local Hugging Face cache that `app.py` populates.
# If the files are genuinely missing (e.g. first run without network), we fall back
# to an online download so the user still gets a working application.
# Determine the cache directory from env – this is set in `app.py`.
hf_cache_dir = os.environ.get("HF_HUB_CACHE", None)
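        # Precache note: `app.py` presumably populates this cache at startup,
        # e.g. via `huggingface_hub.snapshot_download(clip_model_name,
        # cache_dir=hf_cache_dir)`, so the `local_files_only=True` load below
        # normally succeeds without network access (an assumption; the exact
        # call lives in app.py).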
try:
self.clip_model = CLIPModel.from_pretrained(
self.clip_model_name,
cache_dir=hf_cache_dir,
local_files_only=True, # use cache only on the first attempt
)
self.processor = CLIPProcessor.from_pretrained(
self.clip_model_name,
cache_dir=hf_cache_dir,
local_files_only=True,
)
print("CLIP model loaded successfully from local cache")
        except OSError as cache_err:
            print(
                f"Local cache for CLIP model not found ({cache_err}); "
                "attempting a one-time online download..."
            )
# Note: this will still respect HF_HUB_CACHE so the files are cached for future runs.
self.clip_model = CLIPModel.from_pretrained(
self.clip_model_name,
cache_dir=hf_cache_dir,
)
self.processor = CLIPProcessor.from_pretrained(
self.clip_model_name,
cache_dir=hf_cache_dir,
)
print("CLIP model downloaded and cached successfully")
print("CLIP model loaded successfully")
# Freeze CLIP parameters
for param in self.clip_model.parameters():
param.requires_grad = False
# Add projection layers for valence and arousal
hidden_dim = self.clip_model.config.projection_dim
projection_dim = hidden_dim // 2
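        # For the default "openai/clip-vit-large-patch14-336" checkpoint,
        # projection_dim is 768, so each head below maps 768 -> 384 -> 192 -> 1.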
self.valence_head = nn.Sequential(
nn.Linear(hidden_dim, projection_dim),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(projection_dim, projection_dim // 2),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(projection_dim // 2, 1),
nn.Tanh() # Output between -1 and 1
)
self.arousal_head = nn.Sequential(
nn.Linear(hidden_dim, projection_dim),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(projection_dim, projection_dim // 2),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(projection_dim // 2, 1),
nn.Tanh() # Output between -1 and 1
)
# Move model to GPU if available
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.to(self.device)
print(f"Model moved to device: {self.device}")
self._initialized = True

    def forward(self, images: Union[Image.Image, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass to get valence and arousal predictions.

        Args:
            images: A single PIL image, or a tensor of CLIP-preprocessed pixel values.

        Returns:
            Tuple of predicted valence and arousal tensors, each in [-1, 1].
        """
# Ensure model is initialized
self._ensure_initialized()
# Process images if they're PIL images
if isinstance(images, Image.Image):
inputs = self.processor(images=images, return_tensors="pt")
pixel_values = inputs.pixel_values.to(self.device)
else:
pixel_values = images.to(self.device)
# Get CLIP image features
image_features = self.clip_model.get_image_features(pixel_values)
# Project to valence and arousal scores
valence = self.valence_head(image_features)
arousal = self.arousal_head(image_features)
return valence, arousal

    def encode_image(self, image: Image.Image) -> torch.Tensor:
        """Get the raw CLIP image embeddings.

        Args:
            image: PIL image to encode.

        Returns:
            Image embedding tensor of shape (1, projection_dim).
        """
# Ensure model is initialized
self._ensure_initialized()
inputs = self.processor(images=image, return_tensors="pt")
with torch.no_grad():
image_features = self.clip_model.get_image_features(inputs.pixel_values.to(self.device))
return image_features
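

if __name__ == "__main__":
    # Minimal smoke test: a hedged sketch for local experimentation, not part
    # of the ARIA app itself. It triggers the lazy load explicitly, disables
    # dropout in the projection heads, and runs a solid-color dummy image
    # through both entry points. Assumes the CLIP weights are cached locally
    # or downloadable.
    encoder = ImageEncoder()
    encoder._ensure_initialized()  # load CLIP and build the heads up front
    encoder.eval()                 # disable dropout for deterministic output
    dummy = Image.new("RGB", (336, 336), color=(128, 64, 32))
    with torch.no_grad():
        valence, arousal = encoder(dummy)
    print(f"valence={valence.item():+.3f}, arousal={arousal.item():+.3f}")
    embedding = encoder.encode_image(dummy)
    print(f"CLIP embedding shape: {tuple(embedding.shape)}")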