Spaces:
Running
Running
File size: 4,579 Bytes
a0e37e2 bea5044 a0e37e2 bea5044 a0e37e2 f86d7f2 a0e37e2 bea5044 a0e37e2 bea5044 a0e37e2 f86d7f2 a0e37e2 f86d7f2 bea5044 a0e37e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import logging
import json
import tiktoken
from elasticsearch import Elasticsearch
# from pydantic.v1 import BaseModel, Field # <-- Uses v1 namespace
from pydantic import BaseModel, Field
from langchain.tools import StructuredTool
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA
# Module-level setup, executed once at import time.
# Configure the root logger at INFO level and create the named logger used by
# the functions in this module.
logging.basicConfig(level="INFO")
logger = logging.getLogger("elasticsearch_playground")
# Shared Elasticsearch client; the cloud id and API key are read from the
# project's SEMANTIC_ELASTIC_QA connection settings.
es = Elasticsearch(
cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
api_key=SEMANTIC_ELASTIC_QA.api_key,
verify_certs=True,
# 60 * 3 = 180; presumably seconds (i.e. three minutes) - confirm against the
# Elasticsearch client documentation for `request_timeout`.
request_timeout=60 * 3,
)
class SearchToolInput(BaseModel):
    """Input for the index show data tool."""

    # NOTE: pydantic exposes the class docstring and every field description
    # through the generated schema, so that text is kept verbatim.
    # All four fields are required: an Ellipsis default marks a field as
    # mandatory.

    # Index to run the search against.
    index_name: str = Field(
        default=...,
        description="The name of the index for which the data is to be retrieved",
    )

    # Search body, passed as text.
    query: str = Field(
        default=...,
        description="The ElasticSearch JSON query used to filter all hits. Should use the _source field if possible to specify required fields.",
    )

    # Offset of the first record to return.
    from_: int = Field(
        default=...,
        description="The record index from which the query will start",
    )

    # Number of records requested.
    size: int = Field(
        default=...,
        description="How many records will be retrieved from the ElasticSearch query",
    )
def elastic_search(
    pcs_codes: dict,
    index_name: str,
    query: str,
    from_: int = 0,
    size: int = 20,
):
    """Execute a JSON query on an Elasticsearch index and return the result as text.

    ``query`` is parsed as JSON. Its top-level ``query``, ``aggs`` and ``sort``
    entries are forwarded to ``es.search``; when none of the three is present
    the whole document is treated as the query clause.

    NOTE(review): any other top-level key (for example ``_source``) is not
    forwarded to the search request.

    The returned text depends on which parts were supplied:

    * aggregations without a query: ``str(res["aggregations"])``
    * query and aggregations: ``str`` of a dict holding ``hits`` and
      ``aggregations``
    * anything else: ``str(res["hits"])``

    The text is kept within a 6000-token budget, counted with the ``gpt-4``
    tiktoken encoding. While it is over budget the search is repeated with one
    hit fewer. Once lowering ``size`` cannot shrink the text any further
    (aggregation-only request, or ``size`` already at 0) the text is truncated
    to the budget instead of retrying.

    Args:
        pcs_codes: Accepted for interface compatibility; not used by this
            function.
        index_name: Name of the index to search.
        query: JSON text describing the search.
        from_: Offset of the first hit to return.
        size: Requested number of hits; capped at 50.

    Returns:
        The rendered result text, or the exception message (``str(e)``) when
        parsing the query or running the search raises.
    """
    max_hits = 50
    max_tokens = 6000

    size = min(max_hits, size)
    encoding = tiktoken.encoding_for_model("gpt-4")
    try:
        full_dict: dict = json.loads(query)
        query_dict = None
        aggs_dict = None
        sort_dict = None
        if "query" in full_dict:
            query_dict = full_dict["query"]
        if "aggs" in full_dict:
            aggs_dict = full_dict["aggs"]
        if "sort" in full_dict:
            sort_dict = full_dict["sort"]
        if query_dict is None and aggs_dict is None and sort_dict is None:
            # Assume that there is a query but that the query part was omitted.
            query_dict = full_dict

        aggs_only = query_dict is None and aggs_dict is not None
        if aggs_only:
            # Only the aggregations are returned for this kind of request, so
            # no hits need to be fetched at all.
            size = 0

        logger.info(query)

        while True:
            res = es.search(
                index=index_name,
                from_=from_,
                size=size,
                query=query_dict,
                aggs=aggs_dict,
                sort=sort_dict,
            )
            if aggs_only:
                # When a result has aggregations only, return just those.
                final_res = str(res["aggregations"])
            elif query_dict is not None and aggs_dict is not None:
                # Return both hits and aggregations.
                final_res = str(
                    {
                        "hits": res.get("hits", {}),
                        "aggregations": res.get("aggregations", {}),
                    }
                )
            else:
                final_res = str(res["hits"])

            tokens = encoding.encode(final_res)
            if len(tokens) <= max_tokens:
                return final_res
            if aggs_only or size <= 0:
                # Asking for fewer hits cannot make this text any smaller:
                # return what fits in the budget rather than repeating the
                # same request or falling through without a result.
                return encoding.decode(tokens[:max_tokens])
            # Over budget: retry with one hit fewer.
            size -= 1
    except Exception as e:
        # Best-effort tool: report the failure as text instead of raising.
        logger.exception("Could not execute query %s", query)
        msg = str(e)
        return msg
def create_search_tool(pcs_codes):
    """Build the LangChain ``StructuredTool`` that wraps ``elastic_search``.

    Args:
        pcs_codes: Value bound into every ``elastic_search`` call made through
            the returned tool.

    Returns:
        A ``StructuredTool`` named ``elastic_index_search_tool`` that uses
        ``SearchToolInput`` as its argument schema.
    """

    # Closure that fixes ``pcs_codes`` and forwards the four tool arguments.
    def _run_search(index_name, query, from_, size):
        return elastic_search(
            pcs_codes=pcs_codes,
            index_name=index_name,
            query=query,
            from_=from_,
            size=size,
        )

    # Text shown to the model; kept verbatim.
    tool_description = """This tool allows executing queries on an Elasticsearch index efficiently. Provide:
1. index_name (string): The target Elasticsearch index.
2. query (dictionary): Defines the query structure, supporting:
a. Filters: For precise data retrieval (e.g., match, term, range).
b. Aggregations: For statistical summaries and grouping (e.g., sum, average, histogram).
c. Full-text search: For analyzing and ranking text-based results (e.g., match, multi-match, query_string).
"""

    return StructuredTool.from_function(
        func=_run_search,
        name="elastic_index_search_tool",
        description=tool_description,
        args_schema=SearchToolInput,
    )
|