# ask-candid / ask_candid/tools/elastic/index_search_tool.py
# brainsqueeze - "Smarter document context retrieval" (commit f86d7f2, verified)
import logging
import json
import tiktoken
from elasticsearch import Elasticsearch
# from pydantic.v1 import BaseModel, Field # <-- Uses v1 namespace
from pydantic import BaseModel, Field
from langchain.tools import StructuredTool
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA
# Named logger used by the search tool below. basicConfig additionally sets the
# root logger to INFO on import (a no-op if the root logger already has handlers).
logger = logging.getLogger("elasticsearch_playground")
logging.basicConfig(level="INFO")
# Shared Elasticsearch client for every search issued from this module,
# configured from the SEMANTIC_ELASTIC_QA connection settings.
es = Elasticsearch(
    api_key=SEMANTIC_ELASTIC_QA.api_key,
    cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
    request_timeout=180,  # seconds, i.e. three minutes per request
    verify_certs=True,
)
class SearchToolInput(BaseModel):
    """Input for the Elasticsearch index search tool."""

    # Target index; forwarded as the `index` of the search request.
    index_name: str = Field(
        ..., description="The name of the index for which the data is to be retrieved"
    )
    # JSON text (not an object): elastic_search parses it with json.loads.
    query: str = Field(
        ...,
        description="The ElasticSearch JSON query used to filter all hits. Should use the _source field if possible to specify required fields.",
    )
    # Offset of the first hit returned (Elasticsearch `from`).
    from_: int = Field(
        ..., description="The record index from which the query will start"
    )
    # Requested number of hits; elastic_search caps this before querying.
    size: int = Field(
        ...,
        description="How many records will be retrieved from the ElasticSearch query",
    )
def _split_search_body(body):
    """Split a parsed search request into its (query, aggs, sort, _source) parts.

    Absent parts are returned as ``None``. A JSON object containing none of the
    recognised top-level keys is treated as a bare query clause.

    Raises:
        ValueError: If ``body`` is not a JSON object.
    """
    if not isinstance(body, dict):
        raise ValueError("The query must be a JSON object.")
    query_part = body.get("query")
    # Elasticsearch accepts both spellings of the aggregations key.
    aggs_part = body.get("aggs", body.get("aggregations"))
    sort_part = body.get("sort")
    source_part = body.get("_source")
    if all(part is None for part in (query_part, aggs_part, sort_part, source_part)):
        # Assume the caller sent only the query clause and omitted the "query" wrapper.
        query_part = body
    return query_part, aggs_part, sort_part, source_part


def _render_search_result(res, query_dict, aggs_dict, n_hits):
    """Stringify the requested parts of a search response, keeping the first ``n_hits`` hits."""
    if query_dict is None and aggs_dict is not None:
        # Aggregation-only request: the hits are irrelevant, return just the aggregations.
        return str(res["aggregations"])
    # Shallow copy so trimming never mutates the response object.
    hits = dict(res.get("hits", {}))
    if "hits" in hits:
        hits["hits"] = hits["hits"][:n_hits]
    if aggs_dict is not None:
        return str({"hits": hits, "aggregations": res.get("aggregations", {})})
    return str(hits)


def elastic_search(
    pcs_codes: dict,
    index_name: str,
    query: str,
    from_: int = 0,
    size: int = 20,
    *,
    max_size: int = 50,
    max_tokens: int = 6000,
):
    """Execute a JSON search request on an Elasticsearch index and return it as text.

    ``query`` is JSON text that is either a full search body (top-level
    ``query``, ``aggs``/``aggregations``, ``sort`` and/or ``_source``) or a bare
    query clause such as ``{"match": {...}}``, which is used as the ``query``.

    The result is rendered with ``str()``. Hits are dropped from the end of the
    already-fetched hit list until the text fits in ``max_tokens`` tokens
    (tiktoken "gpt-4" encoding), so the search is issued only once.

    Args:
        pcs_codes: Not used by this function; accepted for interface compatibility.
        index_name: Index to search.
        query: JSON-encoded search body or bare query clause.
        from_: Offset of the first hit.
        size: Requested number of hits, clamped to ``[0, max_size]``. Ignored for
            aggregation-only requests, which fetch no hits.
        max_size: Upper bound applied to ``size``.
        max_tokens: Token budget for the returned text.

    Returns:
        The hits, the aggregations, or both, rendered as a string. If the result
        cannot fit even with zero hits, an explanatory message is returned.
        If parsing or the search fails, the exception text is returned instead
        of raising.
    """
    size = max(0, min(max_size, size))
    encoding = tiktoken.encoding_for_model("gpt-4")
    try:
        query_dict, aggs_dict, sort_dict, source = _split_search_body(json.loads(query))
        if query_dict is None and aggs_dict is not None:
            # Aggregation-only request: only the aggregations are returned, so fetch no hits.
            size = 0
        logger.info(query)

        # Only send `_source` when the caller asked for field filtering.
        extra = {} if source is None else {"source": source}
        res = es.search(
            index=index_name,
            from_=from_,
            size=size,
            query=query_dict,
            aggs=aggs_dict,
            sort=sort_dict,
            **extra,
        )

        # Trim hits locally (largest count first) instead of re-querying Elasticsearch.
        for n_hits in range(size, -1, -1):
            rendered = _render_search_result(res, query_dict, aggs_dict, n_hits)
            if len(encoding.encode(rendered)) <= max_tokens:
                return rendered

        logger.warning(
            "Result for query %s exceeds %d tokens even with no hits", query, max_tokens
        )
        return (
            f"The search result is larger than {max_tokens} tokens even without any hits; "
            "narrow the query or reduce the aggregations."
        )
    except Exception as e:
        # Tool boundary: report the failure to the agent as text rather than raising.
        logger.exception("Could not execute query %s", query)
        return str(e)
def create_search_tool(pcs_codes):
    """Build the LangChain tool that lets an agent run Elasticsearch searches.

    Args:
        pcs_codes: Forwarded unchanged to ``elastic_search`` on every call.

    Returns:
        A ``StructuredTool`` named ``elastic_index_search_tool`` whose arguments
        are validated against ``SearchToolInput``.
    """

    def _run_search(index_name, query, from_=0, size=20):
        # Bind pcs_codes so the agent only supplies the search arguments;
        # defaults mirror elastic_search.
        return elastic_search(
            pcs_codes=pcs_codes,
            index_name=index_name,
            query=query,
            from_=from_,
            size=size,
        )

    return StructuredTool.from_function(
        func=_run_search,
        name="elastic_index_search_tool",
        description=(
            "This tool allows executing queries on an Elasticsearch index efficiently. Provide:\n"
            "1. index_name (string): The target Elasticsearch index.\n"
            "2. query (string): A JSON-encoded Elasticsearch request body (a JSON object "
            "serialized as a string). It may contain \"query\", \"aggs\" and \"sort\" keys, "
            "or just a bare query clause. It supports:\n"
            "    a. Filters: For precise data retrieval (e.g., match, term, range).\n"
            "    b. Aggregations: For statistical summaries and grouping (e.g., sum, average, histogram).\n"
            "    c. Full-text search: For analyzing and ranking text-based results "
            "(e.g., match, multi-match, query_string).\n"
            "3. from_ (integer): The offset of the first hit to return.\n"
            "4. size (integer): How many hits to return (at most 50); "
            "ignored for aggregation-only queries.\n"
        ),
        args_schema=SearchToolInput,
    )