# payman/src/rag/optimized_query_processor.py
"""
Optimized Query Processor with Rate Limiting and Better Error Handling
"""
import hashlib
import logging
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
class OptimizedQueryProcessor:
"""
Optimized QueryProcessor with rate limiting and better error handling
"""
def __init__(
self, embedding_generator, vector_db, config: Optional[Dict[str, Any]] = None
):
self.embedding_generator = embedding_generator
self.vector_db = vector_db
self.config = config or {}
self.logger = logging.getLogger(__name__)
        # Optimized configuration settings (defaults favor recall over precision)
        self.top_k = self.config.get("top_k", 10)  # increased from 5
        self.similarity_threshold = self.config.get("similarity_threshold", 0.4)  # lowered from 0.7
        self.max_context_length = self.config.get("max_context_length", 8000)  # characters
        self.enable_caching = self.config.get("enable_caching", True)
        self.cache_ttl = self.config.get("cache_ttl", 7200)  # seconds (2 hours)
        # Rate limiting settings
        self.last_api_call = 0.0
        self.min_api_interval = 1.0  # minimum seconds between API calls
        self.max_retries = 3
        self.retry_delay = 2.0  # base delay in seconds; doubled per rate-limit retry
        # Query cache and history
        self.query_cache = {}
        self.query_history = []  # appended to by process_query
        self.logger.info("OptimizedQueryProcessor initialized")
def process_query(
self, query: str, filter: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Process query with optimized rate limiting and error handling
"""
if not query or not query.strip():
return {
"query": query,
"context": [],
"total_results": 0,
"error": "Empty query provided",
}
self.logger.info(f"Processing query: {query[:100]}...")
start_time = time.time()
try:
# Check cache first
cache_key = self._generate_cache_key(query, filter)
            if self.enable_caching and cache_key in self.query_cache:
                cached_result = self.query_cache[cache_key]
                if self._is_cache_valid(cached_result["timestamp"]):
                    self.logger.info("Returning cached result")
                    # Return a copy so callers cannot mutate the cached entry
                    result = cached_result.copy()
                    result["from_cache"] = True
                    return result
            # Rate limiting is enforced inside _generate_embedding_with_retry;
            # calling _enforce_rate_limit here as well would add an extra,
            # unconditional sleep before every uncached query.
# Generate query embedding with retry logic
query_embedding = self._generate_embedding_with_retry(query)
            if not query_embedding:
                return {
                    "query": query,
                    "context": [],
                    "total_results": 0,
                    "error": "Failed to generate query embedding",
                    "processing_time": time.time() - start_time,
                }
# Search for similar vectors with increased top_k
search_results = self.vector_db.search(
query_embedding=query_embedding,
top_k=self.top_k * 2, # Get more results for better filtering
filter=filter,
include_metadata=True,
)
if not search_results:
self.logger.warning("No search results returned from vector database")
return {
"query": query,
"context": [],
"total_results": 0,
"error": "No similar documents found",
}
# Apply optimized filtering
filtered_results = self._apply_smart_filtering(search_results, query)
# Extract and format context with better error handling
context = self._extract_context_safely(filtered_results)
# Prepare result
result = {
"query": query,
"context": context,
"total_results": len(filtered_results),
"processing_time": time.time() - start_time,
"timestamp": datetime.now(),
"from_cache": False,
"similarity_scores": [r.get("score", 0) for r in filtered_results[:5]],
}
            # Record the query and cache the result
            self.query_history.append({"query": query, "timestamp": result["timestamp"]})
            if self.enable_caching:
                self.query_cache[cache_key] = result.copy()
self.logger.info(
f"Query processed in {result['processing_time']:.2f}s, {len(context)} context items"
)
return result
        except Exception as e:
            self.logger.exception(f"Error processing query: {e}")
return {
"query": query,
"context": [],
"total_results": 0,
"error": str(e),
"processing_time": time.time() - start_time,
}
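    # Every return path above yields at least "query", "context", and
    # "total_results"; failures add "error", successes add "processing_time",
    # "timestamp", "from_cache", and "similarity_scores".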
def _enforce_rate_limit(self):
"""Enforce rate limiting between API calls"""
current_time = time.time()
time_since_last_call = current_time - self.last_api_call
if time_since_last_call < self.min_api_interval:
sleep_time = self.min_api_interval - time_since_last_call
self.logger.info(f"Rate limiting: sleeping {sleep_time:.1f}s")
time.sleep(sleep_time)
self.last_api_call = time.time()
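    # With the default min_api_interval of 1.0s, a call arriving 0.3s after the
    # previous one sleeps for the remaining 0.7s before proceeding.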
def _generate_embedding_with_retry(self, query: str) -> List[float]:
"""Generate embedding with retry logic and rate limiting"""
for attempt in range(self.max_retries):
try:
self._enforce_rate_limit()
embedding = self.embedding_generator.generate_query_embedding(query)
if embedding:
return embedding
else:
self.logger.warning(
f"Attempt {attempt + 1}: Empty embedding returned"
)
except Exception as e:
self.logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
if "429" in str(e) or "quota" in str(e).lower():
# Rate limit hit - wait longer
wait_time = self.retry_delay * (2**attempt)
self.logger.info(f"Rate limit hit, waiting {wait_time}s...")
time.sleep(wait_time)
elif attempt < self.max_retries - 1:
time.sleep(self.retry_delay)
self.logger.error("All embedding generation attempts failed")
return []
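    # Backoff schedule with the defaults above (retry_delay=2.0, max_retries=3):
    # a rate-limited attempt waits 2s, then 4s, then 8s, so a fully throttled
    # query gives up after roughly 14s of waiting.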
def _apply_smart_filtering(
self, search_results: List[Dict[str, Any]], query: str
) -> List[Dict[str, Any]]:
"""Apply smart filtering with adaptive threshold"""
if not search_results:
return []
# Get score statistics
scores = [r.get("score", 0) for r in search_results]
max_score = max(scores)
avg_score = sum(scores) / len(scores)
# Adaptive threshold: use lower threshold if max score is low
adaptive_threshold = min(self.similarity_threshold, max_score * 0.8)
self.logger.info(
f"Score stats - Max: {max_score:.3f}, Avg: {avg_score:.3f}, Threshold: {adaptive_threshold:.3f}"
)
# Filter results
filtered = [
result
for result in search_results[: self.top_k]
if result.get("score", 0) >= adaptive_threshold
]
# If no results pass threshold, return top 3 anyway
if not filtered and search_results:
self.logger.warning(
f"No results above threshold {adaptive_threshold:.3f}, returning top 3"
)
filtered = search_results[:3]
return filtered
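    # Worked example: with similarity_threshold=0.4 and a best score of 0.35,
    # the adaptive threshold becomes min(0.4, 0.35 * 0.8) = 0.28, so decent
    # results still pass on low-scoring queries.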
def _extract_context_safely(
self, search_results: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Extract context with better error handling"""
context = []
total_length = 0
for i, result in enumerate(search_results):
try:
# Multiple ways to extract text content
text = ""
metadata = result.get("metadata", {})
# Try different text fields
for field in ["text", "content", "content_preview", "description"]:
if field in metadata and metadata[field]:
text = str(metadata[field])
break
if not text:
self.logger.warning(f"No text content found in result {i}")
continue
# Check length limit
if total_length + len(text) > self.max_context_length and context:
break
# Create context item
context_item = {
"text": text,
"score": result.get("score", 0),
"source": metadata.get("source", f"Document {i+1}"),
"chunk_id": result.get("id", ""),
"metadata": metadata,
"relevance_rank": len(context) + 1,
}
context.append(context_item)
total_length += len(text)
except Exception as e:
self.logger.warning(f"Error extracting context from result {i}: {e}")
continue
self.logger.info(
f"Extracted {len(context)} context items (total length: {total_length})"
)
return context
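    # Note: the first context item is always kept even when it alone exceeds
    # max_context_length; the "and context" guard only trims later items.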
    def _generate_cache_key(self, query: str, filter: Optional[Dict[str, Any]]) -> str:
        """Generate a deterministic cache key for a query/filter pair"""
        filter_str = str(sorted(filter.items())) if filter else ""
        cache_string = f"{query.lower().strip()}{filter_str}"
        # MD5 is used only as a fast, stable key here, not for security
        return hashlib.md5(cache_string.encode()).hexdigest()
def _is_cache_valid(self, timestamp: datetime) -> bool:
"""Check if cached result is still valid"""
return datetime.now() - timestamp < timedelta(seconds=self.cache_ttl)
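

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: _StubEmbedder and _StubVectorDB are
    # hypothetical stand-ins for the real embedding generator and vector
    # database this class expects (generate_query_embedding() and search()
    # with the keyword arguments used in process_query above).
    logging.basicConfig(level=logging.INFO)

    class _StubEmbedder:
        def generate_query_embedding(self, query):
            return [0.1, 0.2, 0.3]

    class _StubVectorDB:
        def search(self, query_embedding, top_k, filter, include_metadata):
            return [
                {
                    "id": "doc-1",
                    "score": 0.9,
                    "metadata": {"text": "Example passage.", "source": "demo"},
                }
            ]

    processor = OptimizedQueryProcessor(_StubEmbedder(), _StubVectorDB())
    print(processor.process_query("What does this demo return?"))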