import os
from typing import Tuple, Union

import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPModel, CLIPProcessor


class ImageEncoder(nn.Module):
    def __init__(self, clip_model_name: str = "openai/clip-vit-large-patch14-336"):
        """Initialize the image encoder using CLIP.

        Args:
            clip_model_name: HuggingFace model name for CLIP
        """
        super().__init__()
        # Store model name for lazy loading
        self.clip_model_name = clip_model_name
        self.clip_model = None
        self.processor = None
        self.valence_head = None
        self.arousal_head = None
        self.device = None
        self._initialized = False

    def _ensure_initialized(self):
        """Lazily initialize the model components on first use."""
        if self._initialized:
            return

        print(f"Initializing ImageEncoder with {self.clip_model_name}...")
        print("Attempting to load CLIP model from the local cache...")

        # Prefer loading strictly from the local Hugging Face cache that `app.py`
        # populates. If the files are genuinely missing (e.g. a first run where the
        # cache has not been populated), fall back to an online download so the user
        # still gets a working application.

        # Determine the cache directory from the environment - this is set in `app.py`.
        hf_cache_dir = os.environ.get("HF_HUB_CACHE", None)

        try:
            self.clip_model = CLIPModel.from_pretrained(
                self.clip_model_name,
                cache_dir=hf_cache_dir,
                local_files_only=True,  # use the cache only on the first attempt
            )
            self.processor = CLIPProcessor.from_pretrained(
                self.clip_model_name,
                cache_dir=hf_cache_dir,
                local_files_only=True,
            )
            print("CLIP model loaded successfully from local cache")
        except OSError:  # EnvironmentError is an alias of OSError in Python 3
            print("Local cache for CLIP model not found - attempting a one-time online download...")
            # Note: this still respects HF_HUB_CACHE, so the files are cached for future runs.
            self.clip_model = CLIPModel.from_pretrained(
                self.clip_model_name,
                cache_dir=hf_cache_dir,
            )
            self.processor = CLIPProcessor.from_pretrained(
                self.clip_model_name,
                cache_dir=hf_cache_dir,
            )
            print("CLIP model downloaded and cached successfully")

        # Freeze CLIP parameters
        for param in self.clip_model.parameters():
            param.requires_grad = False

        # Add projection layers for valence and arousal
        hidden_dim = self.clip_model.config.projection_dim
        projection_dim = hidden_dim // 2

        self.valence_head = nn.Sequential(
            nn.Linear(hidden_dim, projection_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(projection_dim, projection_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(projection_dim // 2, 1),
            nn.Tanh(),  # Output between -1 and 1
        )
        self.arousal_head = nn.Sequential(
            nn.Linear(hidden_dim, projection_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(projection_dim, projection_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(projection_dim // 2, 1),
            nn.Tanh(),  # Output between -1 and 1
        )

        # Move model to GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)
        print(f"Model moved to device: {self.device}")

        self._initialized = True

    def forward(self, images: Union[Image.Image, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass to get valence and arousal predictions.

        Args:
            images: Either a PIL image or a tensor already in CLIP pixel format

        Returns:
            Tuple of predicted valence and arousal scores
        """
        # Ensure model is initialized
        self._ensure_initialized()

        # Preprocess PIL images; tensors are assumed to be CLIP pixel values already
        if isinstance(images, Image.Image):
            inputs = self.processor(images=images, return_tensors="pt")
            pixel_values = inputs.pixel_values.to(self.device)
        else:
            pixel_values = images.to(self.device)

        # Get CLIP image features
        image_features = self.clip_model.get_image_features(pixel_values)

        # Project to valence and arousal scores
        valence = self.valence_head(image_features)
        arousal = self.arousal_head(image_features)

        return valence, arousal

    def encode_image(self, image: Image.Image) -> torch.Tensor:
        """Get the raw CLIP image embeddings.

        Args:
            image: PIL image to encode

        Returns:
            Image embedding tensor
        """
        # Ensure model is initialized
        self._ensure_initialized()

        inputs = self.processor(images=image, return_tensors="pt")
        with torch.no_grad():
            image_features = self.clip_model.get_image_features(inputs.pixel_values.to(self.device))

        return image_features
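

# --- Usage sketch (not part of the original module) ---------------------------
# A minimal example of how the encoder might be exercised, assuming a sample
# image at the hypothetical path "test.jpg". Initialization is lazy, so the
# CLIP weights are only loaded on the first call.
if __name__ == "__main__":
    encoder = ImageEncoder()
    image = Image.open("test.jpg").convert("RGB")  # hypothetical sample image

    # Valence/arousal predictions, each Tanh-bounded to [-1, 1]
    with torch.no_grad():
        valence, arousal = encoder(image)
    print(f"valence={valence.item():.3f}, arousal={arousal.item():.3f}")

    # Raw CLIP embedding (shape [1, projection_dim], i.e. [1, 768] for ViT-L/14-336)
    embedding = encoder.encode_image(image)
    print(f"embedding shape: {tuple(embedding.shape)}")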