File size: 4,579 Bytes
a0e37e2
 
 
 
 
bea5044
a0e37e2
 
 
 
 
 
 
 
 
 
 
 
bea5044
a0e37e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f86d7f2
a0e37e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bea5044
 
 
 
 
 
 
 
 
a0e37e2
 
bea5044
a0e37e2
 
 
 
 
 
 
 
 
 
 
 
 
f86d7f2
a0e37e2
f86d7f2
 
 
 
 
 
 
bea5044
 
 
 
 
 
 
 
 
 
 
a0e37e2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import logging
import json

import tiktoken
from elasticsearch import Elasticsearch

# from pydantic.v1 import BaseModel, Field  # <-- Uses v1 namespace
from pydantic import BaseModel, Field
from langchain.tools import StructuredTool

from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA

# Configure root logging at INFO and obtain this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("elasticsearch_playground")

# Module-wide client for the SEMANTIC_ELASTIC_QA cluster, created at import
# time. Certificates are verified and each request gets a 180 s timeout.
es = Elasticsearch(
    cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
    api_key=SEMANTIC_ELASTIC_QA.api_key,
    verify_certs=True,
    request_timeout=180,
)


class SearchToolInput(BaseModel):
    """Input for the Elasticsearch index search tool."""

    # Index the search is executed against.
    index_name: str = Field(
        ..., description="The name of the index for which the data is to be retrieved"
    )
    # Search body as a JSON-encoded *string*; the tool parses it with json.loads.
    query: str = Field(
        ...,
        description="The ElasticSearch JSON query used to filter all hits. Should use the _source field if possible to specify required fields.",
    )
    # Named ``from_`` because ``from`` is a reserved word in Python.
    from_: int = Field(
        ..., description="The record index from which the query will start"
    )
    # Number of hits requested; the tool may lower it to fit its output budget.
    size: int = Field(
        ...,
        description="How many records will be retrieved from the ElasticSearch query",
    )


# Appended when a result cannot be brought under the token budget by
# requesting fewer hits and has to be cut instead.
_TRUNCATION_NOTE = (
    "\n[Output truncated: the result exceeded the token limit. "
    "Narrow the query, request fewer fields or use smaller aggregations.]"
)


def _split_search_body(body: dict):
    """Split a parsed search body into its query, aggs, sort and _source parts.

    ``aggregations`` is accepted as an alias of ``aggs``. When none of
    ``query``/``aggs``/``sort`` is present, the body is assumed to be a bare
    query clause whose ``query`` wrapper was omitted; ``_source`` is excluded
    from that clause, and an empty remainder becomes ``None`` (match all).

    Returns:
        A ``(query, aggs, sort, source)`` tuple; absent parts are ``None``.
    """
    query_part = body.get("query")
    aggs_part = body.get("aggs", body.get("aggregations"))
    sort_part = body.get("sort")
    source_part = body.get("_source")
    if query_part is None and aggs_part is None and sort_part is None:
        bare_query = {key: value for key, value in body.items() if key != "_source"}
        query_part = bare_query or None
    return query_part, aggs_part, sort_part, source_part


def _render_result(res, query_part, aggs_part) -> str:
    """Render the requested part(s) of a search response as text.

    Returns ``str()`` of the hits when no aggregations were requested, of the
    aggregations for aggregation-only requests, and of a
    ``{"hits": ..., "aggregations": ...}`` dict when both were requested.
    """
    if aggs_part is None:
        return str(res["hits"])
    if query_part is None:
        return str(res["aggregations"])
    return str(
        {
            "hits": res.get("hits", {}),
            "aggregations": res.get("aggregations", {}),
        }
    )


def elastic_search(
    pcs_codes: dict,
    index_name: str,
    query: str,
    from_: int = 0,
    size: int = 20,
    max_tokens: int = 6000,
    max_size: int = 50,
) -> str:
    """Execute a query on an Elasticsearch index and return hits or aggregation results.

    ``query`` is parsed as JSON. Its ``query``, ``aggs``/``aggregations``,
    ``sort`` and ``_source`` entries are forwarded to ``es.search``. If none of
    ``query``/``aggs``/``sort`` is present, the object is treated as a bare
    query clause.

    The returned text is kept within ``max_tokens`` tokens (GPT-4 encoding).
    The number of requested hits is reduced and the search repeated until the
    text fits. If it still does not fit with zero hits (e.g. large
    aggregations), it is truncated and a note is appended.

    Args:
        pcs_codes: Not used by this function; kept for interface compatibility.
        index_name: Target Elasticsearch index.
        query: JSON-encoded search body.
        from_: Offset of the first hit; negative values are treated as 0.
        size: Number of hits requested, clamped to ``[0, max_size]``. Forced
            to 0 for aggregation-only requests, whose hits are not returned.
        max_tokens: Token budget for the returned text.
        max_size: Upper bound applied to ``size``.

    Returns:
        The rendered result, or the error message if parsing or the search
        fails.
    """
    size = max(0, min(max_size, size))
    from_ = max(0, from_)
    try:
        body = json.loads(query)
        if not isinstance(body, dict):
            raise ValueError(
                f"The query must be a JSON object, got {type(body).__name__}."
            )
        query_part, aggs_part, sort_part, source_part = _split_search_body(body)
        if query_part is None and aggs_part is not None:
            # Aggregation-only request: only the aggregations are returned,
            # so suppress the hits entirely.
            size = 0
        logger.info(query)
        encoding = tiktoken.encoding_for_model("gpt-4")
        while True:
            res = es.search(
                index=index_name,
                from_=from_,
                size=size,
                query=query_part,
                aggs=aggs_part,
                sort=sort_part,
                source=source_part,
            )
            final_res = _render_result(res, query_part, aggs_part)
            tokens = encoding.encode(final_res)
            if len(tokens) <= max_tokens:
                return final_res
            if size == 0:
                # Requesting fewer hits cannot shrink the output any further,
                # so cut the text to the budget instead.
                return encoding.decode(tokens[:max_tokens]) + _TRUNCATION_NOTE
            # Shrink roughly in proportion to the overshoot, but always by at
            # least one hit so the loop is guaranteed to terminate.
            size = max(0, min(size - 1, size * max_tokens // len(tokens)))
    except Exception as e:
        # Tool boundary: report the failure to the caller as text instead of
        # raising.
        logger.exception("Could not execute query %s", query)
        return str(e)


def create_search_tool(pcs_codes):
    """Build the LangChain tool that lets an agent run Elasticsearch searches.

    Args:
        pcs_codes: Passed through unchanged to ``elastic_search`` on every call.

    Returns:
        A ``StructuredTool`` named ``elastic_index_search_tool`` whose arguments
        are validated against ``SearchToolInput``.
    """

    def _run_search(index_name: str, query: str, from_: int = 0, size: int = 20):
        # Closure binds pcs_codes, which is not part of the tool's argument schema.
        return elastic_search(
            pcs_codes=pcs_codes,
            index_name=index_name,
            query=query,
            from_=from_,
            size=size,
        )

    return StructuredTool.from_function(
        func=_run_search,
        name="elastic_index_search_tool",
        # The schema declares ``query`` as a string, so the description asks
        # for a JSON-encoded body rather than a dictionary.
        description=(
            "This tool allows executing queries on an Elasticsearch index efficiently. Provide:\n"
            "1. index_name (string): The target Elasticsearch index.\n"
            "2. query (string): A JSON-encoded Elasticsearch search body, supporting:\n"
            "   a. Filters: For precise data retrieval (e.g., match, term, range).\n"
            "   b. Aggregations: For statistical summaries and grouping (e.g., sum, average, histogram).\n"
            "   c. Full-text search: For analyzing and ranking text-based results (e.g., match, multi-match, query_string).\n"
            "3. from_ (integer): Offset of the first hit to return.\n"
            "4. size (integer): Number of hits to return (at most 50).\n"
        ),
        args_schema=SearchToolInput,
    )