"""
SQL Retriever for RAG System

Intelligent retrieval of relevant SQL examples based on question similarity and table schema analysis.
"""

from typing import Any, Dict, List, Tuple

from loguru import logger

from .vector_store import VectorStore


class SQLRetriever:
    """Intelligent SQL example retriever with schema-aware filtering."""

    def __init__(self, vector_store: VectorStore):
        """
        Initialize the SQL retriever.

        Args:
            vector_store: Initialized vector store instance
        """
        self.vector_store = vector_store
        # Cache of analyzed schemas, keyed by the tuple of table headers.
        self.schema_cache: Dict[Tuple[str, ...], Dict[str, Any]] = {}

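    # retrieve_examples runs a four-stage pipeline: vector search (overfetching
    # 2 * top_k candidates), optional schema-aware filtering, question-type
    # scoring, and diversity-aware final ranking. Each returned dict is the
    # vector-store record augmented with a "final_score" key; fields such as
    # "question", "category", and "table_headers" come from the stored record.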
    def retrieve_examples(self,
                          question: str,
                          table_headers: List[str],
                          top_k: int = 5,
                          use_schema_filtering: bool = True) -> List[Dict[str, Any]]:
        """
        Retrieve relevant SQL examples using multiple retrieval strategies.

        Args:
            question: Natural language question
            table_headers: List of table column names
            top_k: Number of examples to retrieve
            use_schema_filtering: Whether to use schema-aware filtering

        Returns:
            List of retrieved examples with relevance scores
        """
        # Stage 1: semantic vector search, overfetching so that the later
        # filtering stages still leave enough candidates.
        vector_results = self.vector_store.search_similar(
            query=question,
            table_headers=table_headers,
            top_k=top_k * 2,
            similarity_threshold=0.6
        )

        if not vector_results:
            logger.warning("No vector search results found")
            return []

        # Stage 2: optional schema-aware filtering.
        if use_schema_filtering:
            filtered_results = self._apply_schema_filtering(
                vector_results, question, table_headers
            )
        else:
            filtered_results = vector_results

        # Stage 3: boost results whose question type and complexity match.
        enhanced_results = self._enhance_with_question_analysis(
            filtered_results, question, table_headers
        )

        # Stage 4: diversity-aware ranking down to top_k.
        final_results = self._final_ranking(
            enhanced_results, question, table_headers, top_k
        )

        logger.info(f"Retrieved {len(final_results)} relevant examples")
        return final_results

    def _apply_schema_filtering(self,
                                results: List[Dict[str, Any]],
                                question: str,
                                table_headers: List[str]) -> List[Dict[str, Any]]:
        """Apply schema-aware filtering to improve relevance."""
        filtered_results = []

        current_schema = self._analyze_schema(table_headers)

        for result in results:
            # Headers may be stored as a comma-separated string.
            example_headers = result["table_headers"]
            if isinstance(example_headers, str):
                example_headers = [h.strip() for h in example_headers.split(",")]

            example_schema = self._analyze_schema(example_headers)

            schema_similarity = self._calculate_schema_similarity(
                current_schema, example_schema
            )

            # Blend vector similarity with schema similarity (70/30).
            result["schema_similarity"] = schema_similarity
            result["enhanced_score"] = (
                result["similarity_score"] * 0.7 +
                schema_similarity * 0.3
            )

            # Drop examples whose schema barely resembles the current table.
            if schema_similarity > 0.3:
                filtered_results.append(result)

        return filtered_results

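    # Example for _analyze_schema below: headers ["employee_id", "name",
    # "salary", "hire_date"] yield column_count=4, has_numeric/has_text/
    # has_date=True, has_boolean=False, and "employee_id" as a primary-key
    # candidate.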
    def _analyze_schema(self, table_headers: List[str]) -> Dict[str, Any]:
        """Analyze table schema for intelligent matching."""
        if not table_headers:
            return {}

        # Reuse previously analyzed schemas where possible.
        cache_key = tuple(table_headers)
        if cache_key in self.schema_cache:
            return self.schema_cache[cache_key]

        schema_info = {
            "column_count": len(table_headers),
            "column_types": {},
            "has_numeric": False,
            "has_text": False,
            "has_date": False,
            "has_boolean": False,
            "primary_key_candidates": [],
            "foreign_key_candidates": []
        }

        for header in table_headers:
            header_lower = header.lower()

            # Key columns: crude keyword matching on the header name.
            if any(word in header_lower for word in ['id', 'key', 'pk', 'fk']):
                if 'id' in header_lower:
                    schema_info["primary_key_candidates"].append(header)
                if 'fk' in header_lower or 'foreign' in header_lower:
                    schema_info["foreign_key_candidates"].append(header)

            # Heuristic type inference; later matches overwrite earlier ones.
            if any(word in header_lower for word in ['age', 'count', 'number', 'price', 'salary', 'amount']):
                schema_info["has_numeric"] = True
                schema_info["column_types"][header] = "numeric"

            if any(word in header_lower for word in ['name', 'title', 'description', 'text', 'comment']):
                schema_info["has_text"] = True
                schema_info["column_types"][header] = "text"

            if any(word in header_lower for word in ['date', 'time', 'created', 'updated', 'birth']):
                schema_info["has_date"] = True
                schema_info["column_types"][header] = "date"

            if any(word in header_lower for word in ['is_', 'has_', 'active', 'enabled', 'status']):
                schema_info["has_boolean"] = True
                schema_info["column_types"][header] = "boolean"

        self.schema_cache[cache_key] = schema_info
        return schema_info

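    # Worked example for _calculate_schema_similarity (hypothetical numbers):
    # count_similarity=0.8, type_similarity=0.75, and pk_similarity=0.2 blend
    # to 0.8 * 0.4 + 0.75 * 0.4 + 0.2 * 0.2 = 0.66.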
    def _calculate_schema_similarity(self,
                                     schema1: Dict[str, Any],
                                     schema2: Dict[str, Any]) -> float:
        """Calculate similarity between two table schemas."""
        if not schema1 or not schema2:
            return 0.0

        # Column-count similarity, normalized by the first schema's size.
        count_diff = abs(schema1.get("column_count", 0) - schema2.get("column_count", 0))
        count_similarity = max(0, 1 - (count_diff / max(schema1.get("column_count", 1), 1)))

        # Type-profile similarity: 0.25 per matching type flag.
        type_similarity = 0.0
        if schema1.get("has_numeric") == schema2.get("has_numeric"):
            type_similarity += 0.25
        if schema1.get("has_text") == schema2.get("has_text"):
            type_similarity += 0.25
        if schema1.get("has_date") == schema2.get("has_date"):
            type_similarity += 0.25
        if schema1.get("has_boolean") == schema2.get("has_boolean"):
            type_similarity += 0.25

        # Small bonus when both schemas expose primary-key candidates.
        pk_similarity = 0.0
        if (schema1.get("primary_key_candidates") and
                schema2.get("primary_key_candidates")):
            pk_similarity = 0.2

        # Weighted blend of the three signals.
        final_similarity = (
            count_similarity * 0.4 +
            type_similarity * 0.4 +
            pk_similarity * 0.2
        )

        return final_similarity

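    # In the enhancement step below, a category match multiplies the score by
    # 1.2, while the complexity term only rescales within [0.9, 1.0], so
    # category agreement dominates the adjustment.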
    def _enhance_with_question_analysis(self,
                                        results: List[Dict[str, Any]],
                                        question: str,
                                        table_headers: List[str]) -> List[Dict[str, Any]]:
        """Enhance results with question type analysis."""
        question_type = self._classify_question_type(question)
        question_complexity = self._assess_question_complexity(question)

        for result in results:
            # When schema filtering is disabled, results arrive without an
            # enhanced_score, so fall back to the raw similarity score.
            result.setdefault("enhanced_score", result.get("similarity_score", 0.0))

            # Boost examples whose category matches the question type.
            if question_type in result.get("category", "").lower():
                result["enhanced_score"] *= 1.2

            # Reward examples of comparable complexity. Both scores live in
            # [0, 1], so the match factor is 1 minus their distance.
            example_complexity = self._assess_question_complexity(result["question"])
            complexity_match = 1 - abs(question_complexity - example_complexity)
            result["enhanced_score"] *= (0.9 + complexity_match * 0.1)

        return results

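    # Example: "What is the average salary by department?" hits 'average'
    # before any grouping keyword, so _classify_question_type returns
    # "aggregation" rather than "grouping" (branch order decides ties).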
    def _classify_question_type(self, question: str) -> str:
        """Classify the type of SQL question."""
        question_lower = question.lower()

        if any(word in question_lower for word in ['count', 'how many', 'number of',
                                                   'average', 'mean', 'sum', 'total']):
            return "aggregation"
        elif any(word in question_lower for word in ['group by', 'grouped', 'by department', 'by category']):
            return "grouping"
        elif any(word in question_lower for word in ['join', 'combine', 'merge', 'connect']):
            return "join"
        elif any(word in question_lower for word in ['order by', 'sort', 'rank', 'top', 'highest', 'lowest']):
            return "sorting"
        elif any(word in question_lower for word in ['where', 'filter', 'condition']):
            return "filtering"
        else:
            return "general"

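    # Example: "How many employees in each department earn above the average
    # salary?" has 11 words (+0.2), none of the listed SQL keywords, and a
    # question mark (+0.1), for a complexity of roughly 0.3.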
    def _assess_question_complexity(self, question: str) -> float:
        """Assess the complexity of a question (0-1 scale)."""
        complexity_score = 0.0

        # Longer questions tend to describe more involved queries.
        word_count = len(question.split())
        if word_count > 20:
            complexity_score += 0.3
        elif word_count > 10:
            complexity_score += 0.2

        # Explicit SQL constructs in the question signal complexity.
        complex_keywords = ['join', 'group by', 'having', 'subquery', 'union', 'intersect']
        for keyword in complex_keywords:
            if keyword in question.lower():
                complexity_score += 0.15

        # Interrogative form adds a small constant.
        if '?' in question:
            complexity_score += 0.1

        return min(1.0, complexity_score)

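    # Example for the diversity pass below: with top_k=4 and sorted categories
    # [agg, agg, join, agg, sort], the first pass keeps agg, agg (duplicates
    # are allowed while fewer than top_k // 2 = 2 items are chosen), join,
    # and sort.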
    def _final_ranking(self,
                       results: List[Dict[str, Any]],
                       question: str,
                       table_headers: List[str],
                       top_k: int) -> List[Dict[str, Any]]:
        """Final ranking and selection of examples."""
        if not results:
            return []

        # Rank by the blended score computed in the earlier stages.
        results.sort(key=lambda x: x.get("enhanced_score", 0), reverse=True)

        # First pass: prefer category diversity among the top results.
        diverse_results = []
        seen_categories = set()

        for result in results:
            if len(diverse_results) >= top_k:
                break

            category = result.get("category", "general")
            if category not in seen_categories or len(diverse_results) < top_k // 2:
                diverse_results.append(result)
                seen_categories.add(category)

        # Second pass: fill any remaining slots with the best leftovers.
        if len(diverse_results) < top_k:
            for result in results:
                if len(diverse_results) >= top_k:
                    break
                if result not in diverse_results:
                    diverse_results.append(result)

        # Expose a single final_score and drop intermediate scores.
        for result in diverse_results:
            result["final_score"] = result.get("enhanced_score", result.get("similarity_score", 0))
            result.pop("enhanced_score", None)
            result.pop("schema_similarity", None)

        return diverse_results[:top_k]

    def get_retrieval_stats(self) -> Dict[str, Any]:
        """Get statistics about the retrieval system."""
        vector_stats = self.vector_store.get_statistics()

        return {
            "vector_store_stats": vector_stats,
            "schema_cache_size": len(self.schema_cache),
            "retrieval_strategies": [
                "vector_similarity",
                "schema_filtering",
                "question_analysis",
                "diversity_ranking"
            ]
        }
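
# ----------------------------------------------------------------------
# Usage sketch (illustrative only). Assumes a VectorStore has already been
# built and populated elsewhere; its constructor arguments and the fields
# of stored records (such as "sql") depend on that implementation:
#
#   store = VectorStore(...)
#   retriever = SQLRetriever(store)
#   examples = retriever.retrieve_examples(
#       question="How many employees earn more than 50000?",
#       table_headers=["employee_id", "name", "salary", "department"],
#       top_k=3,
#   )
#   for example in examples:
#       print(example["final_score"], example.get("sql"))
# ----------------------------------------------------------------------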