# Spaces: Running
import logging | |
import json | |
import tiktoken | |
from elasticsearch import Elasticsearch | |
# from pydantic.v1 import BaseModel, Field # <-- Uses v1 namespace | |
from pydantic import BaseModel, Field | |
from langchain.tools import StructuredTool | |
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA | |
# Module-wide logging: INFO level for the root handler, plus a named logger
# used by the search helpers below.
logging.basicConfig(level="INFO")
logger = logging.getLogger("elasticsearch_playground")

# Shared Elasticsearch client built from the SEMANTIC_ELASTIC_QA connection
# settings. TLS certificates are verified; requests time out after three
# minutes (request_timeout is given in seconds).
es = Elasticsearch(
    cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
    api_key=SEMANTIC_ELASTIC_QA.api_key,
    request_timeout=180,
    verify_certs=True,
)
class SearchToolInput(BaseModel):
    """Arguments accepted by the Elasticsearch index search tool.

    The fields correspond to the ``index_name``, ``query``, ``from_`` and
    ``size`` parameters of ``elastic_search``. All four are required.
    """

    index_name: str = Field(
        ..., description="The name of the index for which the data is to be retrieved"
    )
    # A JSON-encoded request body (a string, not a parsed object).
    query: str = Field(
        ...,
        description="The ElasticSearch JSON query used to filter all hits. Should use the _source field if possible to specify required fields.",
    )
    # Trailing underscore because `from` is a Python keyword.
    from_: int = Field(
        ..., description="The record index from which the query will start"
    )
    size: int = Field(
        ...,
        description="How many records will be retrieved from the ElasticSearch query",
    )
# Upper bound on the number of hits fetched by a single search.
_MAX_HITS = 50
# Maximum size, in gpt-4 tokens, of the text returned to the caller.
_MAX_RESULT_TOKENS = 6000
# Top-level request-body keys that elastic_search understands.
_BODY_KEYS = frozenset({"query", "aggs", "aggregations", "sort", "_source"})


def _split_request_body(full_dict):
    """Split a request body into its (query, aggs, sort, _source) parts.

    ``aggregations`` is accepted as an alias of ``aggs``. If none of the
    recognised top-level keys is present, the whole body is treated as a
    bare query clause whose ``"query"`` wrapper was omitted.

    Raises:
        ValueError: if the body is not a JSON object.
    """
    if not isinstance(full_dict, dict):
        raise ValueError("The query must be a JSON object.")
    if full_dict.keys().isdisjoint(_BODY_KEYS):
        return full_dict, None, None, None
    aggs_dict = full_dict.get("aggs", full_dict.get("aggregations"))
    return (
        full_dict.get("query"),
        aggs_dict,
        full_dict.get("sort"),
        full_dict.get("_source"),
    )


def _format_search_response(res, query_dict, aggs_dict) -> str:
    """Render the relevant part(s) of a search response as text.

    Aggregation-only requests yield just the aggregations; requests without
    aggregations yield just the hits; requests with both yield both.
    """
    if aggs_dict is None:
        return str(res["hits"])
    if query_dict is None:
        return str(res["aggregations"])
    return str(
        {
            "hits": res.get("hits", {}),
            "aggregations": res.get("aggregations", {}),
        }
    )


def elastic_search(
    pcs_codes: dict,
    index_name: str,
    query: str,
    from_: int = 0,
    size: int = 20,
):
    """Run a search on an Elasticsearch index and return hits and/or aggregations as text.

    The response is kept within ``_MAX_RESULT_TOKENS`` gpt-4 tokens.

    1. When hits are too large, the search is re-run with fewer hits.
    2. When fewer hits cannot help (aggregation-only request, or already zero
       hits), the text is truncated to the token budget.

    Args:
        pcs_codes: Not used by this function; accepted so existing callers
            (``create_search_tool``) keep working.
        index_name: Index to search.
        query: JSON request body, as a string or an already-parsed dict. It
            may contain ``query``, ``aggs``/``aggregations``, ``sort`` and
            ``_source``. A body with none of these keys is treated as a bare
            query clause.
        from_: Offset of the first hit to return.
        size: Number of hits requested; capped at ``_MAX_HITS``. Ignored for
            aggregation-only requests, which fetch no hits.

    Returns:
        str: The rendered hits and/or aggregations. On any failure (bad JSON,
        Elasticsearch error, ...) it returns the exception message instead,
        so the calling agent can correct its query.
    """
    size = min(_MAX_HITS, size)
    encoding = tiktoken.encoding_for_model("gpt-4")
    try:
        full_dict = query if isinstance(query, dict) else json.loads(query)
        query_dict, aggs_dict, sort_dict, source = _split_request_body(full_dict)
        aggs_only = query_dict is None and aggs_dict is not None
        if aggs_only:
            # Only the aggregations are returned, so don't fetch any hits.
            size = 0
        logger.info(query)

        while True:
            res = es.search(
                index=index_name,
                from_=from_,
                size=size,
                query=query_dict,
                aggs=aggs_dict,
                sort=sort_dict,
                source=source,
            )
            final_res = _format_search_response(res, query_dict, aggs_dict)
            tokens = encoding.encode(final_res)
            if len(tokens) <= _MAX_RESULT_TOKENS:
                return final_res
            if aggs_only or size <= 0:
                # Fetching fewer hits cannot shrink this output any further;
                # truncate rather than re-running the same search.
                logger.warning(
                    "Result for index %s exceeds %d tokens; truncating.",
                    index_name,
                    _MAX_RESULT_TOKENS,
                )
                return encoding.decode(tokens[:_MAX_RESULT_TOKENS])
            # Estimate how many hits fit the budget, assuming tokens scale
            # with hit count. Always shrink by at least one so the loop
            # terminates.
            size = min(size - 1, size * _MAX_RESULT_TOKENS // len(tokens))
    except Exception as e:
        # Tool boundary: report the failure to the agent instead of raising.
        logger.exception("Could not execute query %s", query)
        return str(e)
def create_search_tool(pcs_codes):
    """Build the LangChain tool that runs ``elastic_search`` against an index.

    Args:
        pcs_codes: Value passed through unchanged to every ``elastic_search``
            call made by the tool.

    Returns:
        StructuredTool: A tool named ``elastic_index_search_tool`` whose
        arguments are validated by ``SearchToolInput``.
    """

    def _run(index_name, query, from_=0, size=20):
        # Paging defaults mirror elastic_search, so a call that omits them
        # still works.
        return elastic_search(
            pcs_codes=pcs_codes,
            index_name=index_name,
            query=query,
            from_=from_,
            size=size,
        )

    return StructuredTool.from_function(
        func=_run,
        name="elastic_index_search_tool",
        description=(
            "This tool allows executing queries on an Elasticsearch index efficiently. Provide:\n"
            "1. index_name (string): The target Elasticsearch index.\n"
            "2. query (string): A JSON-encoded Elasticsearch request body "
            '(e.g. with "query", "aggs" and "sort" keys), supporting:\n'
            "   a. Filters: For precise data retrieval (e.g., match, term, range).\n"
            "   b. Aggregations: For statistical summaries and grouping (e.g., sum, average, histogram).\n"
            "   c. Full-text search: For analyzing and ranking text-based results (e.g., match, multi-match, query_string).\n"
            "3. from_ (integer): The index of the first hit to return.\n"
            "4. size (integer): How many hits to return (at most 50).\n"
        ),
        args_schema=SearchToolInput,
    )