# amharic-srh-chatbot / chatbot_utils.py
# Residue from the hosting site's file-viewer page (not part of the module):
#   Walelign's picture | Upload 5 files | 9fcadef verified | raw | history blame | 1 kB
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer
import numpy as np
class AmharicChatbot:
    """Retrieval-based Q&A helper backed by sentence embeddings and FAISS.

    The CSV is loaded into ``self.df``; its ``question`` column is embedded
    with the ``intfloat/multilingual-e5-small`` encoder and stored in a flat
    L2 FAISS index. ``get_answer`` embeds a user query and returns the
    nearest stored question/answer pairs.
    """

    def __init__(self, csv_path):
        """Load the Q&A table and the encoder, then build the search index.

        Args:
            csv_path: Location of the CSV file, passed to ``pandas.read_csv``.
                The table must provide ``question`` and ``answer`` columns.
        """
        self.df = pd.read_csv(csv_path)
        self.model = SentenceTransformer("intfloat/multilingual-e5-small")
        self.build_index()

    def build_index(self):
        """Embed every stored question and (re)build ``self.index``.

        Sets ``self.embeddings`` (one row per DataFrame row, in row order) and
        ``self.index`` (a ``faiss.IndexFlatL2`` over those rows), so the ids
        returned by the index are positional row numbers of ``self.df``.
        """
        # NOTE(review): stored questions are prefixed with "passage: " while
        # user queries use "query: ". The E5 model card appears to recommend
        # "query: " on both sides for symmetric matching -- confirm intent.
        passages = ["passage: " + q for q in self.df["question"].tolist()]
        # FAISS only accepts C-contiguous float32 matrices; this is a no-op
        # when the encoder already returns such an array.
        self.embeddings = np.ascontiguousarray(
            self.model.encode(passages, show_progress_bar=True),
            dtype=np.float32,
        )
        self.index = faiss.IndexFlatL2(self.embeddings.shape[1])
        self.index.add(self.embeddings)

    def get_answer(self, query, top_k=3):
        """Return the stored Q&A pairs closest to ``query``.

        Args:
            query: User question text.
            top_k: Maximum number of matches to return.

        Returns:
            A list of ``(question, answer)`` tuples ordered as returned by the
            index search (nearest first). The list is shorter than ``top_k``
            when the index holds fewer entries, and empty when
            ``top_k <= 0``.
        """
        if top_k <= 0:
            return []

        query_embedding = np.ascontiguousarray(
            self.model.encode([f"query: {query}"]),
            dtype=np.float32,
        )
        _distances, neighbor_ids = self.index.search(query_embedding, top_k)

        results = []
        for idx in neighbor_ids[0]:
            # FAISS pads the id array with -1 when fewer than top_k vectors
            # are indexed; iloc[-1] would silently return the LAST row, so
            # padded slots must be skipped.
            if idx < 0:
                continue
            row = self.df.iloc[int(idx)]
            results.append((row["question"], row["answer"]))
        return results