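"""Multimodal embedding helpers for queries.

Provides a CPU-only SentenceTransformer text embedder, an optional CLIP-based
image/text embedder (openai/clip-vit-base-patch32), and a dispatcher that turns
a Query into an EmbeddedQuery, falling back to text-only embeddings whenever the
transformers library or the CLIP weights are unavailable.
"""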
from llm_engineering.domain.queries import Query, EmbeddedQuery
from sentence_transformers import SentenceTransformer
import torch
from PIL import Image
import numpy as np
import logging
import re
# Make transformers optional
try:
    from transformers import CLIPProcessor, CLIPModel

    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    print("Transformers library not available, using fallback text-only embeddings")


class TextEmbedder:
    """CPU-only text embedder backed by a SentenceTransformer model (384-dim for all-MiniLM-L6-v2)."""

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        # Force CPU usage for text embedding
        self.device = "cpu"
        self.model = SentenceTransformer(model_name, device="cpu")

    # def to(self, device: str):
    #     """Move the model to a specific device"""
    #     self.device = device
    #     self.model = self.model.to(device)
    #     return self  # Allow method chaining

    def encode(self, text: str) -> list[float]:
        with torch.no_grad():
            return self.model.encode(text, device="cpu", convert_to_tensor=False).tolist()


class MultimodalEmbeddedQuery:
    """Holds a single embedding formed by concatenating a text embedding and an image embedding."""

    def __init__(self, text_embed: list[float], image_embed: list[float]):
        self.embedding = torch.cat([
            torch.tensor(text_embed),
            torch.tensor(image_embed),
        ]).tolist()
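
# Illustrative note (not exercised elsewhere in this module): all-MiniLM-L6-v2
# produces 384-dim text vectors and CLIP ViT-B/32 produces 512-dim image
# vectors, so a MultimodalEmbeddedQuery built from those two encoders carries
# an 896-dim concatenated embedding. The image path below is hypothetical.
#
#   text_vec = TextEmbedder().encode("a red bicycle")          # 384 floats
#   image_vec = ImageEmbedder().encode("images/bicycle.jpg")   # 512 floats
#   combined = MultimodalEmbeddedQuery(text_vec, image_vec)    # 896 floats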


class MultimodalEmbeddingDispatcher:
    @staticmethod
    def dispatch(query: Query) -> EmbeddedQuery:
        if TRANSFORMERS_AVAILABLE:
            # CLIP text encoder: 512-dim embedding
            embedder = ImageEmbedder()
            embedding = embedder.encode_text(query.content)
        else:
            # Fallback to the text-only embedder (384-dim for all-MiniLM-L6-v2)
            embedder = TextEmbedder()
            embedding = embedder.encode(query.content)
        return EmbeddedQuery(
            embedding=embedding,
            content=query.content,
            metadata=query.metadata,
        )


class ImageEmbedder:
    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        # Always initialize the fallback embedder first to ensure it exists
        print("Initializing fallback TextEmbedder")
        self.fallback_embedder = TextEmbedder()
        self.device = "cpu"

        if not TRANSFORMERS_AVAILABLE:
            # Transformers is missing, so skip CLIP entirely and rely on the fallback
            print("Transformers not available - using fallback text embedder")
            self.model = None
            self.processor = None
            return

        try:
            print(f"Loading CLIP model: {model_name}")
            self.model = CLIPModel.from_pretrained(model_name).to(self.device)
            self.processor = CLIPProcessor.from_pretrained(model_name)
            print("CLIP model loaded successfully")
        except Exception as e:
            logging.warning(f"Failed to load CLIP model: {e}")
            self.model = None
            self.processor = None
            print(f"Creating fallback text embedder due to CLIP load failure: {e}")

    def encode(self, image_path: str) -> list[float]:
        """Image embedding (512-dim)"""
        if not TRANSFORMERS_AVAILABLE or self.model is None:
            print("Using placeholder embedding (512-dim) due to missing CLIP model")
            # Return a placeholder embedding of the right size (512)
            return [0.0] * 512
        try:
            print(f"Loading image from: {image_path}")
            image = Image.open(image_path).convert("RGB")
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)
            with torch.no_grad():
                output = self.model.get_image_features(**inputs)[0].cpu().numpy().tolist()
            if len(output) != 512:
                print(f"Warning: CLIP model output has {len(output)} dimensions, normalizing to 512")
                if len(output) < 512:
                    output = output + [0.0] * (512 - len(output))
                else:
                    output = output[:512]
            return output
        except Exception as e:
            logging.warning(f"Failed to encode image: {e}")
            print(f"Returning zero embedding (512-dim) due to encoding error: {e}")
            return [0.0] * 512

    def encode_text(self, text: str) -> list[float]:
        """Text embedding using CLIP's text encoder (512-dim)"""
        if not TRANSFORMERS_AVAILABLE or self.model is None:
            print("CLIP not available, using fallback text embedder")
            return self._get_normalized_text_embedding(text)
        try:
            # Clean and preprocess the text for CLIP
            try:
                # Clean the text - remove special characters that might cause problems
                # Remove excessive whitespace, newlines, etc.
                text = re.sub(r'\s+', ' ', text).strip()
                # Remove or replace problematic characters
                text = re.sub(r'[^\w\s.,!?\'"-]', '', text)
                # Limit text length aggressively to avoid tokenization issues
                if len(text) > 300:  # CLIP has a limited context window
                    print(f"Text too long for CLIP ({len(text)}), truncating to 300 chars")
                    text = text[:300]  # Truncate to avoid tensor size issues
                print(f"Cleaned text for CLIP: {text[:50] if len(text) > 50 else text}...")
            except Exception as text_clean_error:
                print(f"Error cleaning text: {text_clean_error}. Using fallback.")
                # Just truncate if cleaning fails
                if len(text) > 300:
                    text = text[:300]
            # Try to encode with CLIP with an explicit max length
            try:
                # Use explicit max_length to avoid tensor size mismatches
                inputs = self.processor(
                    text=text,
                    return_tensors="pt",
                    padding="max_length",
                    max_length=77,  # CLIP's standard context length
                    truncation=True,
                ).to(self.device)
                with torch.no_grad():
                    output = self.model.get_text_features(**inputs)[0].cpu().numpy().tolist()
                if len(output) != 512:
                    print(f"Normalizing CLIP output from {len(output)} to 512 dimensions")
                    if len(output) < 512:
                        output = output + [0.0] * (512 - len(output))
                    else:
                        output = output[:512]
                return output
            except RuntimeError as e:
                print(f"CLIP encoding error: {e}")
                if "size mismatch" in str(e) or "dimension" in str(e).lower():
                    print("Tensor size mismatch in CLIP, using fallback")
                    return self._get_normalized_text_embedding(text)
                raise
        except Exception as e:
            logging.warning(f"Failed to encode text with CLIP: {e}")
            print(f"Using fallback text embedder due to error: {e}")
            return self._get_normalized_text_embedding(text)

    def _get_normalized_text_embedding(self, text: str) -> list[float]:
        """Helper to get normalized text embeddings from the fallback embedder"""
        try:
            if self.fallback_embedder is None:
                print("Fallback embedder is None, initializing...")
                self.fallback_embedder = TextEmbedder()
            embed = self.fallback_embedder.encode(text)
            # Ensure 512 dimensions for compatibility
            if len(embed) < 512:
                print(f"Padding fallback embedding from {len(embed)} to 512 dimensions")
                embed = embed + [0.0] * (512 - len(embed))
            elif len(embed) > 512:
                print(f"Truncating fallback embedding from {len(embed)} to 512 dimensions")
                embed = embed[:512]
            return embed
        except Exception as e:
            print(f"Error in fallback embedding: {e}")
            # Last resort: return zeros
            return [0.0] * 512

    def encode_batch(self, image_paths: list[str]) -> list[list[float]]:
        """Batch image embedding (512-dim per image)."""
        if not TRANSFORMERS_AVAILABLE or self.model is None:
            print("CLIP not available for batch encoding, returning placeholders")
            # Return placeholder embeddings
            return [[0.0] * 512 for _ in range(len(image_paths))]
        try:
            print(f"Batch encoding {len(image_paths)} images with CLIP")
            with torch.inference_mode():
                images = []
                for path in image_paths:
                    try:
                        img = Image.open(path).convert("RGB")
                        images.append(img)
                    except Exception as e:
                        print(f"Error opening image {path}: {e}")
                        # Add a black image as placeholder
                        images.append(Image.new("RGB", (224, 224), color="black"))
                if not images:
                    print("No valid images to process")
                    return [[0.0] * 512]
                inputs = self.processor(images=images, return_tensors="pt").to(self.device)
                outputs = self.model.get_image_features(**inputs).cpu().numpy().tolist()
            # Ensure each output has 512 dimensions
            normalized_outputs = []
            for output in outputs:
                if len(output) != 512:
                    if len(output) < 512:
                        output = output + [0.0] * (512 - len(output))
                    else:
                        output = output[:512]
                normalized_outputs.append(output)
            return normalized_outputs
        except Exception as e:
            logging.warning(f"Failed to batch encode images: {e}")
            print(f"Returning placeholder embeddings due to batch encoding error: {e}")
            return [[0.0] * 512 for _ in range(len(image_paths))]
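

if __name__ == "__main__":
    # Minimal smoke test for the embedders. Constructing a real Query depends on
    # llm_engineering.domain.queries, which is not defined here, so only the
    # local embedders are exercised directly.
    text_embedder = TextEmbedder()
    print(f"MiniLM embedding length: {len(text_embedder.encode('hello world'))}")  # 384

    image_embedder = ImageEmbedder()
    print(f"CLIP text embedding length: {len(image_embedder.encode_text('hello world'))}")  # 512 (CLIP or padded fallback)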