# Provenance (from the repository file viewer):
#   author: GGroenendaal
#   commit: af461f3 - "add es retriever"
#   size:   1.33 kB
from datasets import DatasetDict, load_dataset
from src.utils.log import get_logger
from src.retrievers.base_retriever import Retriever
from elasticsearch import Elasticsearch
from dotenv import load_dotenv
import os
# Load variables from a local .env file into the process environment at import
# time, so the ELASTIC_* settings read via os.getenv in ESRetriever are
# available without exporting them manually.
load_dotenv()
# Module-level logger obtained from the project's logging helper.
logger = get_logger()
class ESRetriever(Retriever):
    """Retriever that searches the dataset's "train" split via Elasticsearch.

    Connection settings are read from the environment:

    * ``ELASTIC_HOST`` -- Elasticsearch host (defaults to ``"localhost"``).
    * ``ELASTIC_USERNAME`` / ``ELASTIC_PASSWORD`` -- optional credentials.
      Authentication is only configured when both are set.
    """

    # Name used both for the `datasets` index handle and for the index on the
    # Elasticsearch server.
    _INDEX_NAME = "paragraphs"

    def __init__(self, dataset: DatasetDict) -> None:
        """Connect to Elasticsearch and attach the index to the dataset.

        If the index already exists on the server it is loaded; otherwise it
        is created from the ``"text"`` column of the ``"train"`` split.

        Args:
            dataset: Dataset dictionary; only its ``"train"`` split is used.
        """
        self.dataset = dataset["train"]

        es_host = os.getenv("ELASTIC_HOST", "localhost")
        es_password = os.getenv("ELASTIC_PASSWORD")
        es_username = os.getenv("ELASTIC_USERNAME")

        client_kwargs = {"hosts": [es_host]}
        # Only send credentials when both are configured: os.getenv returns
        # None for unset variables, and a (None, None) tuple is not a valid
        # credential pair for the client.
        if es_username and es_password:
            client_kwargs["http_auth"] = (es_username, es_password)
        self.client = Elasticsearch(**client_kwargs)

        if self.client.indices.exists(index=self._INDEX_NAME):
            # Reuse the index that is already present on the server.
            self.dataset.load_elasticsearch_index(
                self._INDEX_NAME,
                es_index_name=self._INDEX_NAME,
                es_client=self.client,
            )
        else:
            logger.info(f"Creating index '{self._INDEX_NAME}' on {es_host}")
            self.dataset.add_elasticsearch_index(
                column="text",
                index_name=self._INDEX_NAME,
                es_index_name=self._INDEX_NAME,
                es_client=self.client,
            )

    def retrieve(self, query: str, k: int = 5):
        """Return the ``k`` nearest examples for ``query`` from the index.

        Args:
            query: Free-text search query.
            k: Number of examples to retrieve (default 5).

        Returns:
            Whatever ``Dataset.get_nearest_examples`` returns for the index
            (per the `datasets` docs: the scores and the matching examples).
        """
        return self.dataset.get_nearest_examples(self._INDEX_NAME, query, k)