"""

Optimized Query Processor with Rate Limiting and Better Error Handling

"""

import hashlib
import logging
import time
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional


class OptimizedQueryProcessor:
    """

    Optimized QueryProcessor with rate limiting and better error handling

    """

    def __init__(
        self, embedding_generator, vector_db, config: Optional[Dict[str, Any]] = None
    ):
        self.embedding_generator = embedding_generator
        self.vector_db = vector_db
        self.config = config or {}
        self.logger = logging.getLogger(__name__)

        # Optimized configuration settings
        self.top_k = self.config.get("top_k", 10)  # Increased from 5
        self.similarity_threshold = self.config.get(
            "similarity_threshold", 0.4
        )  # Lowered from 0.7
        self.max_context_length = self.config.get(
            "max_context_length", 8000
        )  # Increased
        self.enable_caching = self.config.get("enable_caching", True)
        self.cache_ttl = self.config.get("cache_ttl", 7200)  # 2 hours

        # Rate limiting settings
        self.last_api_call = 0
        self.min_api_interval = 1.0  # Minimum 1 second between API calls
        self.max_retries = 3
        self.retry_delay = 2.0

        # Query cache and history (history is reserved; nothing populates it yet)
        self.query_cache = {}
        self.query_history = []

        self.logger.info("OptimizedQueryProcessor initialized")

    def process_query(
        self, query: str, filter: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Process a query with rate limiting, caching, and error handling."""
        if not query or not query.strip():
            return {
                "query": query,
                "context": [],
                "total_results": 0,
                "error": "Empty query provided",
            }

        self.logger.info(f"Processing query: {query[:100]}...")
        start_time = time.time()

        try:
            # Check cache first
            cache_key = self._generate_cache_key(query, filter)
            if self.enable_caching and cache_key in self.query_cache:
                cached_result = self.query_cache[cache_key]
                if self._is_cache_valid(cached_result["timestamp"]):
                    self.logger.info("Returning cached result")
                    cached_result["from_cache"] = True
                    return cached_result
                # Expired entry: evict it so the cache does not grow unbounded
                del self.query_cache[cache_key]

            # Rate limiting protection
            self._enforce_rate_limit()

            # Generate query embedding with retry logic
            query_embedding = self._generate_embedding_with_retry(query)

            if not query_embedding:
                return {
                    "query": query,
                    "context": [],
                    "total_results": 0,
                    "error": "Failed to generate query embedding",
                }

            # Search for similar vectors with increased top_k
            search_results = self.vector_db.search(
                query_embedding=query_embedding,
                top_k=self.top_k * 2,  # Get more results for better filtering
                filter=filter,
                include_metadata=True,
            )

            if not search_results:
                self.logger.warning("No search results returned from vector database")
                return {
                    "query": query,
                    "context": [],
                    "total_results": 0,
                    "error": "No similar documents found",
                }

            # Apply optimized filtering
            filtered_results = self._apply_smart_filtering(search_results, query)

            # Extract and format context with better error handling
            context = self._extract_context_safely(filtered_results)

            # Prepare result
            result = {
                "query": query,
                "context": context,
                "total_results": len(filtered_results),
                "processing_time": time.time() - start_time,
                "timestamp": datetime.now(),
                "from_cache": False,
                "similarity_scores": [r.get("score", 0) for r in filtered_results[:5]],
            }

            # Cache the result
            if self.enable_caching:
                self.query_cache[cache_key] = result.copy()

            self.logger.info(
                f"Query processed in {result['processing_time']:.2f}s, {len(context)} context items"
            )
            return result

        except Exception as e:
            self.logger.error(f"Error processing query: {str(e)}")
            return {
                "query": query,
                "context": [],
                "total_results": 0,
                "error": str(e),
                "processing_time": time.time() - start_time,
            }

    def _enforce_rate_limit(self):
        """Enforce rate limiting between API calls"""
        current_time = time.time()
        time_since_last_call = current_time - self.last_api_call

        if time_since_last_call < self.min_api_interval:
            sleep_time = self.min_api_interval - time_since_last_call
            self.logger.info(f"Rate limiting: sleeping {sleep_time:.1f}s")
            time.sleep(sleep_time)

        self.last_api_call = time.time()

    def _generate_embedding_with_retry(self, query: str) -> List[float]:
        """Generate embedding with retry logic and rate limiting"""
        for attempt in range(self.max_retries):
            try:
                self._enforce_rate_limit()
                embedding = self.embedding_generator.generate_query_embedding(query)

                if embedding:
                    return embedding
                else:
                    self.logger.warning(
                        f"Attempt {attempt + 1}: Empty embedding returned"
                    )

            except Exception as e:
                self.logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")

                if "429" in str(e) or "quota" in str(e).lower():
                    # Rate limit hit - wait longer
                    wait_time = self.retry_delay * (2**attempt)
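                    # Exponential backoff: with retry_delay=2.0 this waits
                    # 2s, 4s, then 8s across attempts 0, 1, and 2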
                    self.logger.info(f"Rate limit hit, waiting {wait_time}s...")
                    time.sleep(wait_time)
                elif attempt < self.max_retries - 1:
                    time.sleep(self.retry_delay)

        self.logger.error("All embedding generation attempts failed")
        return []

    def _apply_smart_filtering(
        self, search_results: List[Dict[str, Any]], query: str
    ) -> List[Dict[str, Any]]:
        """Apply smart filtering with an adaptive threshold (``query`` is currently unused)."""
        if not search_results:
            return []

        # Get score statistics
        scores = [r.get("score", 0) for r in search_results]
        max_score = max(scores)
        avg_score = sum(scores) / len(scores)

        # Adaptive threshold: use lower threshold if max score is low
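        # e.g. with similarity_threshold=0.4 and a max_score of 0.45, the
        # effective threshold becomes min(0.4, 0.45 * 0.8) = 0.36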
        adaptive_threshold = min(self.similarity_threshold, max_score * 0.8)

        self.logger.info(
            f"Score stats - Max: {max_score:.3f}, Avg: {avg_score:.3f}, Threshold: {adaptive_threshold:.3f}"
        )

        # Filter results
        filtered = [
            result
            for result in search_results[: self.top_k]
            if result.get("score", 0) >= adaptive_threshold
        ]

        # If no results pass threshold, return top 3 anyway
        if not filtered and search_results:
            self.logger.warning(
                f"No results above threshold {adaptive_threshold:.3f}, returning top 3"
            )
            filtered = search_results[:3]

        return filtered

    def _extract_context_safely(
        self, search_results: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Extract context with better error handling"""
        context = []
        total_length = 0

        for i, result in enumerate(search_results):
            try:
                # Multiple ways to extract text content
                text = ""
                metadata = result.get("metadata", {})

                # Try different text fields
                for field in ["text", "content", "content_preview", "description"]:
                    if field in metadata and metadata[field]:
                        text = str(metadata[field])
                        break

                if not text:
                    self.logger.warning(f"No text content found in result {i}")
                    continue

                # Check length limit
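                # (the `and context` guard keeps the first item even when it
                # alone exceeds max_context_length)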
                if total_length + len(text) > self.max_context_length and context:
                    break

                # Create context item
                context_item = {
                    "text": text,
                    "score": result.get("score", 0),
                    "source": metadata.get("source", f"Document {i+1}"),
                    "chunk_id": result.get("id", ""),
                    "metadata": metadata,
                    "relevance_rank": len(context) + 1,
                }

                context.append(context_item)
                total_length += len(text)

            except Exception as e:
                self.logger.warning(f"Error extracting context from result {i}: {e}")
                continue

        self.logger.info(
            f"Extracted {len(context)} context items (total length: {total_length})"
        )
        return context

    def _generate_cache_key(self, query: str, filter: Optional[Dict[str, Any]]) -> str:
        """Generate a cache key from the normalized query and filter."""
        filter_str = str(sorted(filter.items())) if filter else ""
        cache_string = f"{query.lower().strip()}{filter_str}"
        return hashlib.md5(cache_string.encode()).hexdigest()

    def _is_cache_valid(self, timestamp: datetime) -> bool:
        """Check if cached result is still valid"""
        return datetime.now() - timestamp < timedelta(seconds=self.cache_ttl)
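

if __name__ == "__main__":
    # Usage sketch (not part of the original module): the stub classes below
    # are hypothetical stand-ins that only mirror the two calls this class
    # makes -- generate_query_embedding() and search() -- so the processor
    # can be exercised without a real embedding model or vector database.
    logging.basicConfig(level=logging.INFO)

    class _StubEmbeddingGenerator:
        def generate_query_embedding(self, query: str) -> List[float]:
            # Deterministic fake embedding; a real generator would call a model
            return [float(ord(c) % 7) for c in query[:8]]

    class _StubVectorDB:
        def search(self, query_embedding, top_k, filter=None, include_metadata=True):
            # Canned results shaped like the dicts process_query expects
            return [
                {
                    "id": "doc-1",
                    "score": 0.82,
                    "metadata": {"text": "Rate limiting spaces out API calls.", "source": "notes.md"},
                },
                {
                    "id": "doc-2",
                    "score": 0.35,
                    "metadata": {"text": "Caching avoids repeated embedding calls.", "source": "notes.md"},
                },
            ]

    processor = OptimizedQueryProcessor(_StubEmbeddingGenerator(), _StubVectorDB())
    result = processor.process_query("How does rate limiting work?")
    print(result["total_results"], [c["source"] for c in result["context"]])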