from typing import Type, Optional, List
import logging

from pydantic import BaseModel, Field
from elasticsearch import Elasticsearch
from langchain.callbacks.manager import AsyncCallbackManagerForToolRun
from langchain.tools.base import BaseTool

from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA

# NOTE: configures the root logger as an import-time side effect.
logging.basicConfig(level="INFO")
logger = logging.getLogger("elasticsearch_playground")

# Shared module-level client; credentials come from the SEMANTIC_ELASTIC_QA
# connection settings.
es = Elasticsearch(
    cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
    api_key=SEMANTIC_ELASTIC_QA.api_key,
    verify_certs=True,
    request_timeout=60 * 3,
)


class ListIndicesInput(BaseModel):
    """Input for the list indices tool."""

    # Required: the string placed between index names in the tool's output.
    separator: str = Field(..., description="Separator for the list of indices")


class ListIndicesTool(BaseTool):
    """Tool for getting all ElasticSearch indices."""

    # Identifier and human-readable description exposed through BaseTool.
    name: str = "elastic_list_indices"
    description: str = (
        "Input is a delimiter like comma or new line. Output is a separated list of indices in the database. "
        "Always use this tool to get to know the indices in the ElasticSearch cluster."
    )
    # Schema describing the arguments accepted by `_run`.
    args_schema: Optional[Type[BaseModel]] = ListIndicesInput

    def _run(self, separator: str) -> str:
        """Return the names of all non-hidden indices joined by ``separator``.

        Queries the module-level ``es`` client for the ``index`` column of the
        cat-indices listing (sorted by that same column) and drops every name
        that starts with a dot.

        Args:
            separator: String inserted between consecutive index names.

        Returns:
            The joined index names, or an empty string if anything raised
            while querying or formatting (the failure is logged instead of
            being propagated).
        """
        try:
            listing = es.cat.indices(h="index", s="index")
            # Split on whitespace so each entry is a single index name.
            index_names: List[str] = listing.split()
            # Names with a leading dot are treated as hidden and left out.
            visible = (name for name in index_names if not name.startswith("."))
            return separator.join(visible)
        except Exception as exc:
            logger.exception("Could not list indices: %s", exc)
            return ""

    async def _arun(
        self,
        separator: str = "",
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Async entry point; intentionally unsupported for this tool.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("ListIndicesTool does not support async operations")