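"""Embedding utilities for text and image queries.

TextEmbedder wraps a CPU-only SentenceTransformer model, ImageEmbedder wraps
OpenAI's CLIP (when the transformers library is available) with text-only
fallbacks, and MultimodalEmbeddingDispatcher turns a Query into an
EmbeddedQuery using whichever embedder is available.
"""
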
import logging
import re

import torch
from PIL import Image
from sentence_transformers import SentenceTransformer

from llm_engineering.domain.queries import Query, EmbeddedQuery

# Make transformers optional
try:
    from transformers import CLIPProcessor, CLIPModel
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    print("Transformers library not available, using fallback text-only embeddings")


class TextEmbedder:
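    """Text embedder backed by a CPU-only SentenceTransformer model; returns embeddings as plain Python lists."""
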
    def __init__(self, model_name="all-MiniLM-L6-v2"):
        # Force CPU usage for text embedding
        self.device = "cpu"
        self.model = SentenceTransformer(model_name, device="cpu")
    
    
    def encode(self, text: str) -> list[float]:
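        """Return the sentence-transformer embedding of ``text`` as a list of floats."""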
        with torch.no_grad():
            return self.model.encode(text, device="cpu", convert_to_tensor=False).tolist()


class MultimodalEmbeddedQuery:
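    """Query embedding built by concatenating a text embedding and an image embedding into one vector."""
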
    def __init__(self, text_embed: list[float], image_embed: list[float]):
        self.embedding = torch.cat([
            torch.tensor(text_embed), 
            torch.tensor(image_embed)
        ]).tolist()


class MultimodalEmbeddingDispatcher:
    @staticmethod
    def dispatch(query: Query) -> EmbeddedQuery:
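        """Embed the query content, preferring CLIP's text encoder and falling back
        to the sentence-transformer embedder when transformers is unavailable."""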
        if TRANSFORMERS_AVAILABLE:
            embedder = ImageEmbedder()
            embedding = embedder.encode_text(query.content)
        else:
            # Fallback to text-only embedder
            embedder = TextEmbedder()
            embedding = embedder.encode(query.content)
            
        return EmbeddedQuery(
            embedding=embedding,
            content=query.content,
            metadata=query.metadata
        )


class ImageEmbedder:
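    """CLIP-based embedder for images and text that falls back to the
    sentence-transformer text embedder (padded or truncated to 512 dimensions)
    when CLIP cannot be loaded."""
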
    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        # Always initialize fallback embedder first to ensure it exists
        print("Initializing fallback TextEmbedder")
        self.fallback_embedder = TextEmbedder()
        
        if not TRANSFORMERS_AVAILABLE:
            # Create a simple fallback embedder
            print("Transformers not available - using fallback text embedder")
            self.model = None
            self.processor = None
            return
            
        self.device = "cpu"
        try:
            print("Loading CLIP model: {}".format(model_name))
            self.model = CLIPModel.from_pretrained(model_name).to(self.device)
            self.processor = CLIPProcessor.from_pretrained(model_name)
            print("CLIP model loaded successfully")
        except Exception as e:
            logging.warning("Failed to load CLIP model: {}".format(e))
            self.model = None
            self.processor = None
            print("Creating fallback text embedder due to CLIP load failure: {}".format(e))

    def encode(self, image_path: str) -> list[float]:
        """Image embedding (512-dim)"""
        if not TRANSFORMERS_AVAILABLE or self.model is None:
            print("Using placeholder embedding (512-dim) due to missing CLIP model")
            # Return a placeholder embedding of the right size (512)
            return [0.0] * 512
            
        try:
            print("Loading image from: {}".format(image_path))
            image = Image.open(image_path).convert("RGB")
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)
            with torch.no_grad():
                output = self.model.get_image_features(**inputs)[0].cpu().numpy().tolist()
                if len(output) != 512:
                    print("Warning: CLIP model output has {} dimensions, normalizing to 512".format(len(output)))
                    if len(output) < 512:
                        output = output + [0.0] * (512 - len(output))
                    else:
                        output = output[:512]
                return output
        except Exception as e:
            logging.warning("Failed to encode image: {}".format(e))
            print("Returning zero embedding (512-dim) due to encoding error: {}".format(e))
            return [0.0] * 512

    def encode_text(self, text: str) -> list[float]:
        """Text embedding using CLIP's text encoder (512-dim)"""
        if not TRANSFORMERS_AVAILABLE or self.model is None:
            print("CLIP not available, using fallback text embedder")
            return self._get_normalized_text_embedding(text)
            
        try:
            # Clean and preprocess the text for CLIP
            try:
                # Clean the text - remove special characters that might cause problems
                # Remove excessive whitespace, newlines, etc.
                text = re.sub(r'\s+', ' ', text).strip()
                # Remove or replace problematic characters
                text = re.sub(r'[^\w\s.,!?\'"-]', '', text)
                
                # Limit text length aggressively to avoid tokenization issues
                if len(text) > 300:  # CLIP has limited context window
                    print("Text too long for CLIP ({}), truncating to 300 chars".format(len(text)))
                    text = text[:300]  # Truncate to avoid tensor size issues
                
                print("Cleaned text for CLIP: {}...".format(text[:50] if len(text) > 50 else text))
            except Exception as text_clean_error:
                print("Error cleaning text: {}. Using fallback.".format(text_clean_error))
                # Just truncate if cleaning fails
                if len(text) > 300:
                    text = text[:300]
            
            # Try to encode with CLIP with explicit max length
            try:
                # Use explicit max_length to avoid tensor size mismatches
                inputs = self.processor(
                    text=text,
                    return_tensors="pt",
                    padding="max_length",
                    max_length=77,  # CLIP's standard context length
                    truncation=True
                ).to(self.device)
                
                with torch.no_grad():
                    output = self.model.get_text_features(**inputs)[0].cpu().numpy().tolist()
                    if len(output) != 512:
                        print("Normalizing CLIP output from {} to 512 dimensions".format(len(output)))
                        if len(output) < 512:
                            output = output + [0.0] * (512 - len(output))
                        else:
                            output = output[:512]
                    return output
            except RuntimeError as e:
                print("CLIP encoding error: {}".format(e))
                if "size mismatch" in str(e) or "dimension" in str(e).lower():
                    print("Tensor size mismatch in CLIP, using fallback")
                    return self._get_normalized_text_embedding(text)
                raise
        except Exception as e:
            logging.warning("Failed to encode text with CLIP: {}".format(e))
            print("Using fallback text embedder due to error: {}".format(e))
            return self._get_normalized_text_embedding(text)
    
    def _get_normalized_text_embedding(self, text: str) -> list[float]:
        """Helper to get normalized text embeddings from the fallback embedder"""
        try:
            if self.fallback_embedder is None:
                print("Fallback embedder is None, initializing...")
                self.fallback_embedder = TextEmbedder()
                
            embed = self.fallback_embedder.encode(text)
            # Ensure 512 dimensions for compatibility
            if len(embed) < 512:
                print("Padding fallback embedding from {} to 512 dimensions".format(len(embed)))
                embed = embed + [0.0] * (512 - len(embed))
            elif len(embed) > 512:
                print("Truncating fallback embedding from {} to 512 dimensions".format(len(embed)))
                embed = embed[:512]
            return embed
        except Exception as e:
            print("Error in fallback embedding: {}".format(e))
            # Last resort: return zeros
            return [0.0] * 512
    
    def encode_batch(self, image_paths: list[str]) -> list[list[float]]:
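        """Embed a batch of images with CLIP, substituting a black placeholder
        image for unreadable paths and zero vectors when CLIP is unavailable."""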
        if not TRANSFORMERS_AVAILABLE or self.model is None:
            print("CLIP not available for batch encoding, returning placeholders")
            # Return placeholder embeddings
            return [[0.0] * 512 for _ in range(len(image_paths))]
            
        try:
            print("Batch encoding {} images with CLIP".format(len(image_paths)))
            with torch.inference_mode():
                images = []
                for path in image_paths:
                    try:
                        img = Image.open(path).convert("RGB")
                        images.append(img)
                    except Exception as e:
                        print("Error opening image {}: {}".format(path, e))
                        # Add a black image as placeholder
                        images.append(Image.new('RGB', (224, 224), color='black'))
                
                if not images:
                    print("No valid images to process")
                    return [[0.0] * 512]
                    
                inputs = self.processor(images=images, return_tensors="pt").to(self.device)
                outputs = self.model.get_image_features(**inputs).cpu().numpy().tolist()
                
                # Ensure each output has 512 dimensions
                normalized_outputs = []
                for output in outputs:
                    if len(output) != 512:
                        if len(output) < 512:
                            output = output + [0.0] * (512 - len(output))
                        else:
                            output = output[:512]
                    normalized_outputs.append(output)
                
                return normalized_outputs
        except Exception as e:
            logging.warning("Failed to batch encode images: {}".format(e))
            print("Returning placeholder embeddings due to batch encoding error: {}".format(e))
            return [[0.0] * 512 for _ in range(len(image_paths))]
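

if __name__ == "__main__":
    # Minimal usage sketch for manual smoke testing. The image path below is a
    # placeholder; a missing or unreadable file simply exercises the zero-vector
    # fallback, so the script runs either way.
    text_embedder = TextEmbedder()
    print("Text embedding size: {}".format(len(text_embedder.encode("hello world"))))

    image_embedder = ImageEmbedder()
    print("CLIP text embedding size: {}".format(len(image_embedder.encode_text("a photo of a cat"))))
    print("Image embedding size: {}".format(len(image_embedder.encode("path/to/example.jpg"))))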