import json
import logging

import tiktoken
from elasticsearch import Elasticsearch

# from pydantic.v1 import BaseModel, Field  # <-- Uses v1 namespace
from pydantic import BaseModel, Field
from langchain.tools import StructuredTool

from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA

logging.basicConfig(level="INFO")
logger = logging.getLogger("elasticsearch_playground")

# Upper bound on the number of hits requested from Elasticsearch per call.
MAX_PAGE_SIZE = 50
# Upper bound on the size of the rendered result, in "gpt-4" tokens.
MAX_RESULT_TOKENS = 6000

es = Elasticsearch(
    cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
    api_key=SEMANTIC_ELASTIC_QA.api_key,
    verify_certs=True,
    request_timeout=60 * 3,
)


class SearchToolInput(BaseModel):
    """Input for the index show data tool."""

    # Target index of the search request.
    index_name: str = Field(
        ..., description="The name of the index for which the data is to be retrieved"
    )
    # JSON text; parsed with json.loads by elastic_search.
    query: str = Field(
        ...,
        description=(
            "The ElasticSearch JSON query used to filter all hits. "
            "Should use the _source field if possible to specify required fields."
        ),
    )
    # Offset of the first hit (trailing underscore avoids the `from` keyword).
    from_: int = Field(
        ..., description="The record index from which the query will start"
    )
    # Requested page size; elastic_search caps it at MAX_PAGE_SIZE.
    size: int = Field(
        ...,
        description="How many records will be retrieved from the ElasticSearch query",
    )


def _fits_token_budget(encoding, text: str) -> bool:
    """Return True when *text* encodes to at most MAX_RESULT_TOKENS tokens."""
    return len(encoding.encode(text)) <= MAX_RESULT_TOKENS


def _result_too_large_message() -> str:
    """Build the message returned when a result cannot be shrunk any further."""
    return (
        f"The result is larger than {MAX_RESULT_TOKENS} tokens even without any hits. "
        "Narrow the query, request fewer _source fields or reduce the number of "
        "aggregation buckets."
    )


def elastic_search(
    pcs_codes: dict,
    index_name: str,
    query: str,
    from_: int = 0,
    size: int = 20,
) -> str:
    """Execute a query on an Elasticsearch index and return hits and/or aggregations.

    The rendered result is kept within MAX_RESULT_TOKENS tokens by dropping
    hits from the end of the page until it fits.

    Args:
        pcs_codes: Not used by this function; kept so existing callers keep working.
        index_name: Name of the index to search.
        query: JSON text of the request. May contain the top-level keys
            ``query``, ``aggs`` and ``sort``; if none of them is present the
            whole object is treated as the ``query`` part.
        from_: Offset of the first hit to return.
        size: Number of hits to request, capped at MAX_PAGE_SIZE.

    Returns:
        ``str()`` of the aggregations (aggregation-only request), of a dict with
        ``hits`` and ``aggregations`` (query plus aggregations), or of the hits
        section (everything else). If parsing or the search fails, the exception
        text is returned instead of raising. If the result cannot be made to fit
        the token budget, an explanatory message is returned.
    """
    size = min(MAX_PAGE_SIZE, size)
    encoding = tiktoken.encoding_for_model("gpt-4")

    try:
        full_dict = json.loads(query)
        if not isinstance(full_dict, dict):
            raise ValueError("The query must be a JSON object")

        query_dict = full_dict.get("query")
        aggs_dict = full_dict.get("aggs")
        sort_dict = full_dict.get("sort")
        if query_dict is None and aggs_dict is None and sort_dict is None:
            # Assume that there is a query but that the query part was omitted.
            query_dict = full_dict

        aggs_only = query_dict is None and aggs_dict is not None
        if aggs_only:
            # Only the aggregations are returned, so no hits need to be fetched.
            size = 0

        logger.info(query)

        res = es.search(
            index=index_name,
            from_=from_,
            size=size,
            query=query_dict,
            aggs=aggs_dict,
            sort=sort_dict,
        )

        if aggs_only:
            # Aggregations do not depend on the page size, so there is nothing
            # to shrink: either the result fits or it does not.
            rendered = str(res["aggregations"])
            if _fits_token_budget(encoding, rendered):
                return rendered
            return _result_too_large_message()

        # Shallow copy so the hit list can be shortened without touching `res`.
        hits_section = dict(res.get("hits", {}))
        while True:
            if aggs_dict is not None:
                # Return both hits and aggregations.
                rendered = str(
                    {
                        "hits": hits_section,
                        "aggregations": res.get("aggregations", {}),
                    }
                )
            else:
                rendered = str(hits_section)

            if _fits_token_budget(encoding, rendered):
                return rendered

            remaining_hits = hits_section.get("hits")
            if not remaining_hits:
                # Nothing left to drop; never fall through and return None.
                return _result_too_large_message()
            # Drop the last hit locally instead of re-running the search with
            # a smaller page size.
            hits_section["hits"] = remaining_hits[:-1]

    except Exception as e:
        # Tool boundary: report the failure as text rather than raising.
        logger.exception("Could not execute query %s", query)
        return str(e)


def create_search_tool(pcs_codes):
    """Build the LangChain tool that runs ``elastic_search`` with *pcs_codes* bound.

    Args:
        pcs_codes: Value forwarded as ``pcs_codes`` on every tool invocation.

    Returns:
        A ``StructuredTool`` named ``elastic_index_search_tool`` whose arguments
        are validated against ``SearchToolInput``.
    """

    def _search(index_name, query, from_, size):
        # Adapter that exposes only the SearchToolInput fields to the tool.
        return elastic_search(
            pcs_codes=pcs_codes,
            index_name=index_name,
            query=query,
            from_=from_,
            size=size,
        )

    return StructuredTool.from_function(
        func=_search,
        name="elastic_index_search_tool",
        description=(
            """This tool allows executing queries on an Elasticsearch index efficiently. Provide:
            1. index_name (string): The target Elasticsearch index.
            2. query (dictionary): Defines the query structure, supporting:
                a. Filters: For precise data retrieval (e.g., match, term, range).
                b. Aggregations: For statistical summaries and grouping (e.g., sum, average, histogram).
                c. Full-text search: For analyzing and ranking text-based results (e.g., match, multi-match, query_string).
            """
        ),
        args_schema=SearchToolInput,
    )