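"""Embedding utilities for multimodal retrieval.

Provides a CPU-only SentenceTransformer text embedder, an optional CLIP-based
image/text embedder (used when the `transformers` package is available), and a
dispatcher that turns a `Query` into an `EmbeddedQuery`, using CLIP text
features when possible and a plain SentenceTransformer embedding otherwise.
"""
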
import logging
import re

import numpy as np
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer

from llm_engineering.domain.queries import EmbeddedQuery, Query

# Make transformers optional
try:
    from transformers import CLIPModel, CLIPProcessor

    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    print("Transformers library not available, using fallback text-only embeddings")

class TextEmbedder:
    def __init__(self, model_name="all-MiniLM-L6-v2"):
        # Force CPU usage for text embedding
        self.device = "cpu"
        self.model = SentenceTransformer(model_name, device="cpu")

    # def to(self, device: str):
    #     """Move the model to a specific device"""
    #     self.device = device
    #     self.model = self.model.to(device)
    #     return self  # Allow method chaining

    def encode(self, text: str) -> list[float]:
        with torch.no_grad():
            return self.model.encode(text, device="cpu", convert_to_tensor=False).tolist()

class MultimodalEmbeddedQuery:
    def __init__(self, text_embed: list[float], image_embed: list[float]):
        # Concatenate the text and image vectors into a single embedding
        self.embedding = torch.cat([
            torch.tensor(text_embed),
            torch.tensor(image_embed),
        ]).tolist()

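# Example (sketch): with the defaults above, TextEmbedder ("all-MiniLM-L6-v2")
# yields 384-dim vectors and CLIP ViT-B/32 yields 512-dim vectors, so the
# concatenated MultimodalEmbeddedQuery.embedding is 896-dim. The image path
# below is purely illustrative:
#
#   text_vec = TextEmbedder().encode("a red bicycle")      # 384 floats
#   img_vec = ImageEmbedder().encode("bike.jpg")           # 512 floats
#   fused = MultimodalEmbeddedQuery(text_vec, img_vec)     # 896 floats
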
class MultimodalEmbeddingDispatcher:
    @staticmethod
    def dispatch(query: Query) -> EmbeddedQuery:
        if TRANSFORMERS_AVAILABLE:
            # Use CLIP's text encoder when available (shares the embedding space
            # with CLIP image features)
            embedder = ImageEmbedder()
            embedding = embedder.encode_text(query.content)
        else:
            # Fallback to text-only embedder
            embedder = TextEmbedder()
            embedding = embedder.encode(query.content)
        return EmbeddedQuery(
            embedding=embedding,
            content=query.content,
            metadata=query.metadata,
        )

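# Usage sketch (assumes `Query` exposes `.content` and `.metadata` as used above;
# the `Query(...)` construction here is illustrative, not the actual domain API):
#
#   query = Query(content="vintage camera with leather strap", metadata={"source": "ui"})
#   embedded = MultimodalEmbeddingDispatcher.dispatch(query)
#   # embedded.embedding is 512-dim when CLIP loads, 384-dim with the text-only fallback
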
class ImageEmbedder:
    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        # Always initialize the fallback embedder first so it exists on every code path
        print("Initializing fallback TextEmbedder")
        self.fallback_embedder = TextEmbedder()
        self.device = "cpu"

        if not TRANSFORMERS_AVAILABLE:
            print("Transformers not available - using fallback text embedder")
            self.model = None
            self.processor = None
            return

        try:
            print("Loading CLIP model: {}".format(model_name))
            self.model = CLIPModel.from_pretrained(model_name).to(self.device)
            self.processor = CLIPProcessor.from_pretrained(model_name)
            print("CLIP model loaded successfully")
        except Exception as e:
            logging.warning("Failed to load CLIP model: {}".format(e))
            self.model = None
            self.processor = None
            print("Creating fallback text embedder due to CLIP load failure: {}".format(e))

    def encode(self, image_path: str) -> list[float]:
        """Image embedding (512-dim)"""
        if not TRANSFORMERS_AVAILABLE or self.model is None:
            print("Using placeholder embedding (512-dim) due to missing CLIP model")
            # Return a placeholder embedding of the right size (512)
            return [0.0] * 512

        try:
            print("Loading image from: {}".format(image_path))
            image = Image.open(image_path).convert("RGB")
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)
            with torch.no_grad():
                output = self.model.get_image_features(**inputs)[0].cpu().numpy().tolist()
            if len(output) != 512:
                print("Warning: CLIP model output has {} dimensions, normalizing to 512".format(len(output)))
                if len(output) < 512:
                    output = output + [0.0] * (512 - len(output))
                else:
                    output = output[:512]
            return output
        except Exception as e:
            logging.warning("Failed to encode image: {}".format(e))
            print("Returning zero embedding (512-dim) due to encoding error: {}".format(e))
            return [0.0] * 512

    def encode_text(self, text: str) -> list[float]:
        """Text embedding using CLIP's text encoder (512-dim)"""
        if not TRANSFORMERS_AVAILABLE or self.model is None:
            print("CLIP not available, using fallback text embedder")
            return self._get_normalized_text_embedding(text)

        try:
            # Clean and preprocess the text for CLIP
            try:
                # Clean the text - remove special characters that might cause problems
                # Remove excessive whitespace, newlines, etc.
                text = re.sub(r'\s+', ' ', text).strip()
                # Remove or replace problematic characters
                text = re.sub(r'[^\w\s.,!?\'"-]', '', text)
                # Limit text length aggressively to avoid tokenization issues
                if len(text) > 300:  # CLIP has a limited context window
                    print("Text too long for CLIP ({}), truncating to 300 chars".format(len(text)))
                    text = text[:300]  # Truncate to avoid tensor size issues
                print("Cleaned text for CLIP: {}...".format(text[:50] if len(text) > 50 else text))
            except Exception as text_clean_error:
                print("Error cleaning text: {}. Using fallback.".format(text_clean_error))
                # Just truncate if cleaning fails
                if len(text) > 300:
                    text = text[:300]

            # Try to encode with CLIP with an explicit max length
            try:
                # Use explicit max_length to avoid tensor size mismatches
                inputs = self.processor(
                    text=text,
                    return_tensors="pt",
                    padding="max_length",
                    max_length=77,  # CLIP's standard context length
                    truncation=True,
                ).to(self.device)
                with torch.no_grad():
                    output = self.model.get_text_features(**inputs)[0].cpu().numpy().tolist()
                if len(output) != 512:
                    print("Normalizing CLIP output from {} to 512 dimensions".format(len(output)))
                    if len(output) < 512:
                        output = output + [0.0] * (512 - len(output))
                    else:
                        output = output[:512]
                return output
            except RuntimeError as e:
                print("CLIP encoding error: {}".format(e))
                if "size mismatch" in str(e) or "dimension" in str(e).lower():
                    print("Tensor size mismatch in CLIP, using fallback")
                    return self._get_normalized_text_embedding(text)
                raise
        except Exception as e:
            logging.warning("Failed to encode text with CLIP: {}".format(e))
            print("Using fallback text embedder due to error: {}".format(e))
            return self._get_normalized_text_embedding(text)

    def _get_normalized_text_embedding(self, text: str) -> list[float]:
        """Helper to get normalized text embeddings from the fallback embedder"""
        try:
            if self.fallback_embedder is None:
                print("Fallback embedder is None, initializing...")
                self.fallback_embedder = TextEmbedder()
            embed = self.fallback_embedder.encode(text)
            # Ensure 512 dimensions for compatibility
            if len(embed) < 512:
                print("Padding fallback embedding from {} to 512 dimensions".format(len(embed)))
                embed = embed + [0.0] * (512 - len(embed))
            elif len(embed) > 512:
                print("Truncating fallback embedding from {} to 512 dimensions".format(len(embed)))
                embed = embed[:512]
            return embed
        except Exception as e:
            print("Error in fallback embedding: {}".format(e))
            # Last resort: return zeros
            return [0.0] * 512

    def encode_batch(self, image_paths: list) -> list:
        if not TRANSFORMERS_AVAILABLE or self.model is None:
            print("CLIP not available for batch encoding, returning placeholders")
            # Return placeholder embeddings
            return [[0.0] * 512 for _ in range(len(image_paths))]

        try:
            print("Batch encoding {} images with CLIP".format(len(image_paths)))
            with torch.inference_mode():
                images = []
                for path in image_paths:
                    try:
                        img = Image.open(path).convert("RGB")
                        images.append(img)
                    except Exception as e:
                        print("Error opening image {}: {}".format(path, e))
                        # Add a black image as placeholder
                        images.append(Image.new('RGB', (224, 224), color='black'))
                if not images:
                    print("No valid images to process")
                    return [[0.0] * 512]
                inputs = self.processor(images=images, return_tensors="pt").to(self.device)
                outputs = self.model.get_image_features(**inputs).cpu().numpy().tolist()
            # Ensure each output has 512 dimensions
            normalized_outputs = []
            for output in outputs:
                if len(output) != 512:
                    if len(output) < 512:
                        output = output + [0.0] * (512 - len(output))
                    else:
                        output = output[:512]
                normalized_outputs.append(output)
            return normalized_outputs
        except Exception as e:
            logging.warning("Failed to batch encode images: {}".format(e))
            print("Returning placeholder embeddings due to batch encoding error: {}".format(e))
            return [[0.0] * 512 for _ in range(len(image_paths))]
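

# Minimal smoke test (sketch): exercises the text path only, since it needs no
# image files on disk. Printed dimensions assume the default model names above
# (512 for CLIP or the padded fallback, 384 for all-MiniLM-L6-v2).
if __name__ == "__main__":
    sample_text = "a mountain landscape at sunset"

    clip_vec = ImageEmbedder().encode_text(sample_text)
    print("ImageEmbedder.encode_text -> {} dims".format(len(clip_vec)))

    plain_vec = TextEmbedder().encode(sample_text)
    print("TextEmbedder.encode -> {} dims".format(len(plain_vec)))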