"""
Optimized Query Processor with Rate Limiting and Better Error Handling
"""
import hashlib
import logging
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional


class OptimizedQueryProcessor:
"""
Optimized QueryProcessor with rate limiting and better error handling
"""

    def __init__(
self, embedding_generator, vector_db, config: Optional[Dict[str, Any]] = None
):
self.embedding_generator = embedding_generator
self.vector_db = vector_db
self.config = config or {}
self.logger = logging.getLogger(__name__)
# Optimized configuration settings
self.top_k = self.config.get("top_k", 10) # Increased from 5
self.similarity_threshold = self.config.get(
"similarity_threshold", 0.4
) # Lowered from 0.7
self.max_context_length = self.config.get(
"max_context_length", 8000
) # Increased
self.enable_caching = self.config.get("enable_caching", True)
self.cache_ttl = self.config.get("cache_ttl", 7200) # 2 hours
# Rate limiting settings
self.last_api_call = 0
self.min_api_interval = 1.0 # Minimum 1 second between API calls
self.max_retries = 3
self.retry_delay = 2.0
# Query cache and history
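        # Note: the cache is an unbounded in-memory dict; entries are only
        # invalidated by the TTL check on read, never evicted.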
self.query_cache = {}
self.query_history = []
self.logger.info("OptimizedQueryProcessor initialized")

    def process_query(
self, query: str, filter: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Process query with optimized rate limiting and error handling
"""
if not query or not query.strip():
return {
"query": query,
"context": [],
"total_results": 0,
"error": "Empty query provided",
}
self.logger.info(f"Processing query: {query[:100]}...")
start_time = time.time()
try:
# Check cache first
cache_key = self._generate_cache_key(query, filter)
if self.enable_caching and cache_key in self.query_cache:
cached_result = self.query_cache[cache_key]
if self._is_cache_valid(cached_result["timestamp"]):
self.logger.info("Returning cached result")
cached_result["from_cache"] = True
return cached_result
# Rate limiting protection
self._enforce_rate_limit()
# Generate query embedding with retry logic
query_embedding = self._generate_embedding_with_retry(query)
if not query_embedding:
return {
"query": query,
"context": [],
"total_results": 0,
"error": "Failed to generate query embedding",
}
# Search for similar vectors with increased top_k
search_results = self.vector_db.search(
query_embedding=query_embedding,
top_k=self.top_k * 2, # Get more results for better filtering
filter=filter,
include_metadata=True,
)
if not search_results:
self.logger.warning("No search results returned from vector database")
return {
"query": query,
"context": [],
"total_results": 0,
"error": "No similar documents found",
}
# Apply optimized filtering
filtered_results = self._apply_smart_filtering(search_results, query)
# Extract and format context with better error handling
context = self._extract_context_safely(filtered_results)
# Prepare result
result = {
"query": query,
"context": context,
"total_results": len(filtered_results),
"processing_time": time.time() - start_time,
"timestamp": datetime.now(),
"from_cache": False,
"similarity_scores": [r.get("score", 0) for r in filtered_results[:5]],
}
# Cache the result
if self.enable_caching:
self.query_cache[cache_key] = result.copy()
self.logger.info(
f"Query processed in {result['processing_time']:.2f}s, {len(context)} context items"
)
return result
except Exception as e:
self.logger.error(f"Error processing query: {str(e)}")
return {
"query": query,
"context": [],
"total_results": 0,
"error": str(e),
"processing_time": time.time() - start_time,
}

    def _enforce_rate_limit(self):
"""Enforce rate limiting between API calls"""
current_time = time.time()
time_since_last_call = current_time - self.last_api_call
if time_since_last_call < self.min_api_interval:
sleep_time = self.min_api_interval - time_since_last_call
self.logger.info(f"Rate limiting: sleeping {sleep_time:.1f}s")
time.sleep(sleep_time)
self.last_api_call = time.time()

    def _generate_embedding_with_retry(self, query: str) -> List[float]:
"""Generate embedding with retry logic and rate limiting"""
for attempt in range(self.max_retries):
try:
self._enforce_rate_limit()
embedding = self.embedding_generator.generate_query_embedding(query)
if embedding:
return embedding
else:
self.logger.warning(
f"Attempt {attempt + 1}: Empty embedding returned"
)
except Exception as e:
self.logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
if "429" in str(e) or "quota" in str(e).lower():
                    # Rate limit hit: back off exponentially (2s, 4s, 8s with the default retry_delay)
                    wait_time = self.retry_delay * (2**attempt)
self.logger.info(f"Rate limit hit, waiting {wait_time}s...")
time.sleep(wait_time)
elif attempt < self.max_retries - 1:
time.sleep(self.retry_delay)
self.logger.error("All embedding generation attempts failed")
return []

    def _apply_smart_filtering(
self, search_results: List[Dict[str, Any]], query: str
) -> List[Dict[str, Any]]:
"""Apply smart filtering with adaptive threshold"""
if not search_results:
return []
# Get score statistics
scores = [r.get("score", 0) for r in search_results]
max_score = max(scores)
avg_score = sum(scores) / len(scores)
# Adaptive threshold: use lower threshold if max score is low
adaptive_threshold = min(self.similarity_threshold, max_score * 0.8)
self.logger.info(
f"Score stats - Max: {max_score:.3f}, Avg: {avg_score:.3f}, Threshold: {adaptive_threshold:.3f}"
)
# Filter results
filtered = [
result
for result in search_results[: self.top_k]
if result.get("score", 0) >= adaptive_threshold
]
# If no results pass threshold, return top 3 anyway
if not filtered and search_results:
self.logger.warning(
f"No results above threshold {adaptive_threshold:.3f}, returning top 3"
)
filtered = search_results[:3]
return filtered

    def _extract_context_safely(
self, search_results: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Extract context with better error handling"""
context = []
total_length = 0
for i, result in enumerate(search_results):
try:
# Multiple ways to extract text content
text = ""
metadata = result.get("metadata", {})
# Try different text fields
for field in ["text", "content", "content_preview", "description"]:
if field in metadata and metadata[field]:
text = str(metadata[field])
break
if not text:
self.logger.warning(f"No text content found in result {i}")
continue
# Check length limit
if total_length + len(text) > self.max_context_length and context:
break
# Create context item
context_item = {
"text": text,
"score": result.get("score", 0),
"source": metadata.get("source", f"Document {i+1}"),
"chunk_id": result.get("id", ""),
"metadata": metadata,
"relevance_rank": len(context) + 1,
}
context.append(context_item)
total_length += len(text)
except Exception as e:
self.logger.warning(f"Error extracting context from result {i}: {e}")
continue
self.logger.info(
f"Extracted {len(context)} context items (total length: {total_length})"
)
return context

    def _generate_cache_key(self, query: str, filter: Optional[Dict[str, Any]]) -> str:
        """Generate cache key for query"""
        filter_str = str(sorted(filter.items())) if filter else ""
        cache_string = f"{query.lower().strip()}{filter_str}"
        return hashlib.md5(cache_string.encode()).hexdigest()

    def _is_cache_valid(self, timestamp: datetime) -> bool:
"""Check if cached result is still valid"""
return datetime.now() - timestamp < timedelta(seconds=self.cache_ttl)
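

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): _StubEmbeddingGenerator and _StubVectorDB
# are hypothetical stand-ins, assuming the real components expose
# generate_query_embedding() and search() with the signatures called above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    class _StubEmbeddingGenerator:
        def generate_query_embedding(self, query: str) -> List[float]:
            # A real implementation would call an embedding model or API here.
            return [0.1] * 8

    class _StubVectorDB:
        def search(self, query_embedding, top_k, filter=None, include_metadata=True):
            # Return one fake hit shaped like the records the processor expects.
            return [
                {
                    "id": "doc-1-chunk-0",
                    "score": 0.82,
                    "metadata": {"text": "Example chunk text.", "source": "doc-1"},
                }
            ]

    processor = OptimizedQueryProcessor(
        embedding_generator=_StubEmbeddingGenerator(),
        vector_db=_StubVectorDB(),
        config={"top_k": 5, "similarity_threshold": 0.5},
    )
    print(processor.process_query("What does the example document say?"))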