from llm_engineering.domain.queries import Query, EmbeddedQuery
from sentence_transformers import SentenceTransformer
import torch
from PIL import Image
import numpy as np
import logging
import re

# Make transformers optional
try:
    from transformers import CLIPProcessor, CLIPModel
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    print("Transformers library not available, using fallback text-only embeddings")


class TextEmbedder:
    """Sentence-transformers text embedder, forced onto the CPU."""

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        # Force CPU usage for text embedding
        self.device = "cpu"
        self.model = SentenceTransformer(model_name, device="cpu")

    # def to(self, device: str):
    #     """Move the model to a specific device"""
    #     self.device = device
    #     self.model = self.model.to(device)
    #     return self  # Allow method chaining

    def encode(self, text: str) -> list[float]:
        with torch.no_grad():
            return self.model.encode(text, device="cpu", convert_to_tensor=False).tolist()


class MultimodalEmbeddedQuery:
    """Concatenates a text embedding and an image embedding into a single vector."""

    def __init__(self, text_embed: list[float], image_embed: list[float]):
        self.embedding = torch.cat([
            torch.tensor(text_embed),
            torch.tensor(image_embed)
        ]).tolist()


class MultimodalEmbeddingDispatcher:
    """Routes a query to the CLIP text encoder when available, otherwise to the text-only embedder."""

    @staticmethod
    def dispatch(query: Query) -> EmbeddedQuery:
        if TRANSFORMERS_AVAILABLE:
            embedder = ImageEmbedder()
            embedding = embedder.encode_text(query.content)
        else:
            # Fall back to the text-only embedder
            embedder = TextEmbedder()
            embedding = embedder.encode(query.content)
        return EmbeddedQuery(
            embedding=embedding,
            content=query.content,
            metadata=query.metadata
        )


class ImageEmbedder:
    """CLIP-based image/text embedder with a text-only fallback; all outputs are 512-dim."""

    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        # Always initialize the fallback embedder first to ensure it exists
        print("Initializing fallback TextEmbedder")
        self.fallback_embedder = TextEmbedder()
        self.device = "cpu"
        if not TRANSFORMERS_AVAILABLE:
            # No transformers library - rely on the fallback text embedder only
            print("Transformers not available - using fallback text embedder")
            self.model = None
            self.processor = None
            return
        try:
            print("Loading CLIP model: {}".format(model_name))
            self.model = CLIPModel.from_pretrained(model_name).to(self.device)
            self.processor = CLIPProcessor.from_pretrained(model_name)
            print("CLIP model loaded successfully")
        except Exception as e:
            logging.warning("Failed to load CLIP model: {}".format(e))
            self.model = None
            self.processor = None
            print("Using fallback text embedder due to CLIP load failure: {}".format(e))

    def encode(self, image_path: str) -> list[float]:
        """Image embedding (512-dim)"""
        if not TRANSFORMERS_AVAILABLE or self.model is None:
            print("Using placeholder embedding (512-dim) due to missing CLIP model")
            # Return a placeholder embedding of the right size (512)
            return [0.0] * 512
        try:
            print("Loading image from: {}".format(image_path))
            image = Image.open(image_path).convert("RGB")
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)
            with torch.no_grad():
                output = self.model.get_image_features(**inputs)[0].cpu().numpy().tolist()
            if len(output) != 512:
                print("Warning: CLIP model output has {} dimensions, normalizing to 512".format(len(output)))
                if len(output) < 512:
                    output = output + [0.0] * (512 - len(output))
                else:
                    output = output[:512]
            return output
        except Exception as e:
            logging.warning("Failed to encode image: {}".format(e))
            print("Returning zero embedding (512-dim) due to encoding error: {}".format(e))
            return [0.0] * 512

    def encode_text(self, text: str) -> list[float]:
        """Text embedding using CLIP's text encoder (512-dim)"""
        if not TRANSFORMERS_AVAILABLE or self.model is None:
            print("CLIP not available, using fallback text embedder")
            return self._get_normalized_text_embedding(text)
        try:
            # Clean and preprocess the text for CLIP
            try:
                # Remove excessive whitespace, newlines, etc.
                text = re.sub(r'\s+', ' ', text).strip()
                # Remove or replace problematic characters
                text = re.sub(r'[^\w\s.,!?\'"-]', '', text)
                # Limit text length aggressively to avoid tokenization issues
                if len(text) > 300:  # CLIP has a limited context window
                    print("Text too long for CLIP ({}), truncating to 300 chars".format(len(text)))
                    text = text[:300]  # Truncate to avoid tensor size issues
                print("Cleaned text for CLIP: {}...".format(text[:50] if len(text) > 50 else text))
            except Exception as text_clean_error:
                print("Error cleaning text: {}. Using fallback.".format(text_clean_error))
                # Just truncate if cleaning fails
                if len(text) > 300:
                    text = text[:300]
            # Try to encode with CLIP with an explicit max length
            try:
                # Use explicit max_length to avoid tensor size mismatches
                inputs = self.processor(
                    text=text,
                    return_tensors="pt",
                    padding="max_length",
                    max_length=77,  # CLIP's standard context length
                    truncation=True
                ).to(self.device)
                with torch.no_grad():
                    output = self.model.get_text_features(**inputs)[0].cpu().numpy().tolist()
                if len(output) != 512:
                    print("Normalizing CLIP output from {} to 512 dimensions".format(len(output)))
                    if len(output) < 512:
                        output = output + [0.0] * (512 - len(output))
                    else:
                        output = output[:512]
                return output
            except RuntimeError as e:
                print("CLIP encoding error: {}".format(e))
                if "size mismatch" in str(e) or "dimension" in str(e).lower():
                    print("Tensor size mismatch in CLIP, using fallback")
                    return self._get_normalized_text_embedding(text)
                raise
        except Exception as e:
            logging.warning("Failed to encode text with CLIP: {}".format(e))
            print("Using fallback text embedder due to error: {}".format(e))
            return self._get_normalized_text_embedding(text)

    def _get_normalized_text_embedding(self, text: str) -> list[float]:
        """Get a text embedding from the fallback embedder, padded or truncated to 512 dimensions."""
        try:
            if self.fallback_embedder is None:
                print("Fallback embedder is None, initializing...")
                self.fallback_embedder = TextEmbedder()
            embed = self.fallback_embedder.encode(text)
            # Ensure 512 dimensions for compatibility with the CLIP embeddings
            if len(embed) < 512:
                print("Padding fallback embedding from {} to 512 dimensions".format(len(embed)))
                embed = embed + [0.0] * (512 - len(embed))
            elif len(embed) > 512:
                print("Truncating fallback embedding from {} to 512 dimensions".format(len(embed)))
                embed = embed[:512]
            return embed
        except Exception as e:
            print("Error in fallback embedding: {}".format(e))
            # Last resort: return zeros
            return [0.0] * 512

    def encode_batch(self, image_paths: list) -> list:
        """Batch image embeddings (one 512-dim vector per path)."""
        if not TRANSFORMERS_AVAILABLE or self.model is None:
            print("CLIP not available for batch encoding, returning placeholders")
            # Return placeholder embeddings
            return [[0.0] * 512 for _ in range(len(image_paths))]
        try:
            print("Batch encoding {} images with CLIP".format(len(image_paths)))
            with torch.inference_mode():
                images = []
                for path in image_paths:
                    try:
                        img = Image.open(path).convert("RGB")
                        images.append(img)
                    except Exception as e:
                        print("Error opening image {}: {}".format(path, e))
                        # Add a black image as placeholder
                        images.append(Image.new('RGB', (224, 224), color='black'))
                if not images:
                    print("No valid images to process")
                    return [[0.0] * 512]
                inputs = self.processor(images=images, return_tensors="pt").to(self.device)
                outputs = self.model.get_image_features(**inputs).cpu().numpy().tolist()
            # Ensure each output has 512 dimensions
            normalized_outputs = []
            for output in outputs:
                if len(output) != 512:
                    if len(output) < 512:
                        output = output + [0.0] * (512 - len(output))
                    else:
                        output = output[:512]
                normalized_outputs.append(output)
            return normalized_outputs
        except Exception as e:
            logging.warning("Failed to batch encode images: {}".format(e))
            print("Returning placeholder embeddings due to batch encoding error: {}".format(e))
            return [[0.0] * 512 for _ in range(len(image_paths))]
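

# Minimal usage sketch (not part of the original module; added here only to illustrate how the
# embedders above fit together). It avoids constructing Query/EmbeddedQuery objects, and it still
# runs when transformers/CLIP weights are unavailable because ImageEmbedder falls back to the
# text-only embedder. The example sentence is arbitrary.
if __name__ == "__main__":
    text_embedder = TextEmbedder()
    text_vec = text_embedder.encode("a red bicycle leaning against a brick wall")
    print("Text embedding dimensions: {}".format(len(text_vec)))  # 384 for all-MiniLM-L6-v2

    image_embedder = ImageEmbedder()
    # 512-dim vector from CLIP's text encoder, or a padded fallback embedding
    clip_text_vec = image_embedder.encode_text("a red bicycle leaning against a brick wall")
    print("CLIP text embedding dimensions: {}".format(len(clip_text_vec)))

    # Concatenate the two embeddings into a single multimodal vector
    combined = MultimodalEmbeddedQuery(text_vec, clip_text_vec)
    print("Combined embedding dimensions: {}".format(len(combined.embedding)))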