Spaces:
Running
Running
import asyncio
import logging
from typing import List, Optional, Type

from elasticsearch import Elasticsearch
from langchain.callbacks.manager import AsyncCallbackManagerForToolRun
from langchain.tools.base import BaseTool
from pydantic import BaseModel, Field

from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA
# NOTE(review): basicConfig at import time configures the root logger for the
# whole process (and is a no-op if the root logger already has handlers);
# consider moving logging setup to the application entry point.
logging.basicConfig(level="INFO")
logger = logging.getLogger("elasticsearch_playground")

# Per-request timeout for the shared client, in seconds (three minutes).
_ES_REQUEST_TIMEOUT_SECONDS = 3 * 60

# Module-wide Elasticsearch client built from the SEMANTIC_ELASTIC_QA
# connection settings; the tools below query the cluster through it.
es = Elasticsearch(
    cloud_id=SEMANTIC_ELASTIC_QA.cloud_id,
    api_key=SEMANTIC_ELASTIC_QA.api_key,
    verify_certs=True,
    request_timeout=_ES_REQUEST_TIMEOUT_SECONDS,
)
class ListIndicesInput(BaseModel):
    """Argument schema for the list-indices tool.

    Attributes:
        separator: Required string used to join the returned index names
            (for example ``","`` or ``"\\n"``).
    """

    # ``default=...`` (Ellipsis) marks the field as required.
    separator: str = Field(
        default=...,
        description="Separator for the list of indices",
    )
class ListIndicesTool(BaseTool):
    """LangChain tool that lists the non-hidden indices of the Elasticsearch cluster.

    The tool queries the module-level ``es`` client through the ``_cat/indices``
    API and returns the index names joined by a caller-chosen separator.
    """

    name: str = "elastic_list_indices"
    description: str = (
        "Input is a delimiter like comma or new line. Output is a separated list of indices in the database. "
        "Always use this tool to get to know the indices in the ElasticSearch cluster."
    )
    # Schema LangChain uses to validate and describe the tool's arguments.
    args_schema: Optional[Type[BaseModel]] = ListIndicesInput

    def _run(self, separator: str) -> str:
        """Return the names of all non-hidden indices joined by ``separator``.

        Args:
            separator: String placed between consecutive index names.

        Returns:
            The index names (requested sorted by name via ``s="index"``) with
            dot-prefixed indices removed. Returns an empty string if the
            Elasticsearch call fails; the failure is logged with its traceback.
        """
        try:
            # ``h="index"`` limits the cat output to the index-name column, so
            # splitting on whitespace yields one name per entry.
            indices: List[str] = es.cat.indices(h="index", s="index").split()
        except Exception as e:  # broad on purpose: the agent gets a value, not a crash
            logger.exception("Could not list indices: %s", e)
            # NOTE(review): "" is indistinguishable from "cluster has no visible
            # indices"; consider returning an explicit error message instead.
            return ""
        # Filter out hidden indices, which start with a dot.
        return separator.join(index for index in indices if not index.startswith("."))

    async def _arun(
        self,
        separator: str = "\n",
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Asynchronous variant of ``_run``.

        The Elasticsearch client used by ``_run`` is synchronous, so the
        blocking call is executed in the running event loop's default
        executor instead of blocking the loop.

        Args:
            separator: String placed between consecutive index names. Defaults
                to a line break so the returned names stay separable.
            run_manager: Callback manager supplied by LangChain; not used here.

        Returns:
            The same value ``_run`` would return for ``separator``.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run, separator)