# rag_hydro/elastic/retrieval.py
# Anas Bader
# redo
# 4cbe4e9
from typing import List, Dict, Any, Optional
import logging
from elasticsearch import exceptions
from elastic.es_client import get_es_client
# Module-level logger for this retrieval module.
logger: logging.Logger = logging.getLogger(__name__)
# Shared Elasticsearch client, created once when this module is imported
# (get_es_client comes from elastic.es_client; how it connects is not visible
# here). It is also bound as the default `es_client` argument of
# search_certification_chunks, so that default is fixed at import time.
es_client = get_es_client()
def search_certification_chunks(
    index_name: str,
    text_query: str,
    vector_query: List[float],
    certification_name: str,
    es_client=es_client,
    vector_field: str = "embedding",
    text_field: str = "chunk",
    size: int = 5,
    min_score: float = 0.1,  # Lowered threshold
    boost_text: float = 1.0,
    boost_vector: float = 1.0,
) -> List[Dict[str, Any]]:
    """Run a hybrid (text match + cosine similarity) search restricted to one certification.

    A cheap pre-check first confirms that at least one document has the given
    certification; otherwise an error is logged and an empty list returned.
    The main query is a ``bool`` whose ``filter`` restricts hits to the
    certification and whose ``should`` clauses add a ``match`` score on
    ``text_field`` and ``cosineSimilarity + 1.0`` on ``vector_field``.

    Args:
        index_name: Elasticsearch index to query.
        text_query: Free text matched against ``text_field``.
        vector_query: Query embedding compared against ``vector_field``;
            presumably must have the indexed vectors' dimension (not checked here).
        certification_name: Exact value of the ``certification`` field every
            returned chunk must have (matched with a ``term`` query).
        es_client: Elasticsearch client; defaults to the module-level client.
        vector_field: Vector field used by ``cosineSimilarity``.
        text_field: Text field that is matched and returned as ``text``.
        size: Maximum number of hits to return.
        min_score: Hits whose combined score is below this are dropped by
            Elasticsearch (top-level ``min_score``).
        boost_text: Boost applied to the text ``match`` clause.
        boost_vector: Boost applied to the vector ``script_score`` clause.

    Returns:
        One dict per hit with keys ``id``, ``score``, ``text`` and
        ``source_file``, in the order Elasticsearch returns them. ``text`` /
        ``source_file`` are ``None`` if missing from the hit's ``_source``.
        An empty list if no document has ``certification_name``.

    Raises:
        Any exception raised by ``es_client.search`` (not caught here).
    """
    # Existence check so an unknown/misspelled certification is reported
    # explicitly rather than silently producing no results.
    cert_check = es_client.search(
        index=index_name,
        body={
            "query": {"term": {"certification": certification_name}},
            "size": 1,
            "_source": False,
        },
    )
    if not cert_check["hits"]["hits"]:
        logger.error("No documents found with certification: %s", certification_name)
        return []

    query_body = _build_hybrid_query(
        text_query=text_query,
        vector_query=vector_query,
        certification_name=certification_name,
        vector_field=vector_field,
        text_field=text_field,
        size=size,
        min_score=min_score,
        boost_text=boost_text,
        boost_vector=boost_vector,
    )
    logger.debug("Elasticsearch query body: %s", query_body)
    logger.info("Executing search on index '%s'", index_name)
    # Deliberately no ``routing``: routing must use the key documents were
    # indexed with, and an arbitrary matching document's ``_id`` is not that
    # key, so routing by it could skip shards holding relevant chunks.
    response = es_client.search(index=index_name, body=query_body)
    hits = response.get("hits", {}).get("hits", [])
    logger.info("Found %d matching documents", len(hits))

    results = [_hit_to_result(hit, text_field) for hit in hits]
    if results:
        logger.debug("Top result score: %s", results[0]["score"])
        logger.debug("Top result source: %s", results[0]["source_file"])
    else:
        logger.warning("No results returned from Elasticsearch")
    return results


def _painless_string(value: str) -> str:
    """Quote *value* as a single-quoted Painless string literal."""
    escaped = value.replace("\\", "\\\\").replace("'", "\\'")
    return f"'{escaped}'"


def _build_hybrid_query(
    text_query: str,
    vector_query: List[float],
    certification_name: str,
    vector_field: str,
    text_field: str,
    size: int,
    min_score: float,
    boost_text: float,
    boost_vector: float,
) -> Dict[str, Any]:
    """Build the hybrid search body: certification filter + text and vector scoring clauses."""
    cosine_source = (
        f"cosineSimilarity(params.query_vector, {_painless_string(vector_field)}) + 1.0"
    )
    return {
        "size": size,
        "min_score": min_score,
        # Only fetch the fields that end up in the results.
        "_source": [text_field, "source_file"],
        "query": {
            "bool": {
                # Hard restriction to the requested certification; does not
                # contribute to the score.
                "filter": [{"term": {"certification": certification_name}}],
                "should": [
                    {"match": {text_field: {"query": text_query, "boost": boost_text}}},
                    {
                        "script_score": {
                            "query": {"match_all": {}},
                            # +1.0 keeps the score non-negative (cosine is in [-1, 1]).
                            "script": {
                                "source": cosine_source,
                                "params": {"query_vector": vector_query},
                            },
                            "boost": boost_vector,
                        }
                    },
                ],
            }
        },
    }


def _hit_to_result(hit: Dict[str, Any], text_field: str) -> Dict[str, Any]:
    """Map one raw Elasticsearch hit to the public result dict."""
    source = hit.get("_source") or {}
    return {
        "id": hit["_id"],
        "score": hit["_score"],
        "text": source.get(text_field),
        "source_file": source.get("source_file"),
    }