Spaces:
Sleeping
Sleeping
"""Retrieval component for the RAG system.""" | |
import faiss
import numpy as np
from typing import List, Dict, Optional, Tuple
from elasticsearch import Elasticsearch
from transformers import AutoTokenizer, AutoModel
import torch
class FinancialDataRetriever:
    """Retrieve documents through a FAISS index or Elasticsearch.

    The backend is selected by ``config['rag']['retriever']``: the value
    ``"faiss"`` uses an in-memory ``faiss.IndexFlatL2`` over BERT ``[CLS]``
    embeddings; every other value takes the Elasticsearch code path (an
    Elasticsearch client is only created for the exact value
    ``"elasticsearch"``).
    """

    def __init__(self, config: Dict):
        """Initialize the retriever with configuration.

        Args:
            config: Mapping with a ``'rag'`` section providing the keys
                ``'retriever'``, ``'max_documents'`` and
                ``'similarity_threshold'``.

        Raises:
            KeyError: If any of the required configuration keys is missing.
        """
        rag_config = config['rag']
        self.retriever_type = rag_config['retriever']
        self.max_documents = rag_config['max_documents']
        self.similarity_threshold = rag_config['similarity_threshold']

        # Initialize FAISS index
        self.dimension = 768  # BERT embedding dimension
        self.index = faiss.IndexFlatL2(self.dimension)
        # Documents in insertion order; position i corresponds to row i of
        # the FAISS index. Created here so retrieve() works before indexing.
        self.documents: List[Dict] = []

        # Initialize transformer model for embeddings
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        self.model = AutoModel.from_pretrained('bert-base-uncased')

        # Initialize Elasticsearch if needed
        if self.retriever_type == "elasticsearch":
            self.es = Elasticsearch()

    def encode_text(self, texts: List[str]) -> np.ndarray:
        """Encode text using BERT.

        Args:
            texts: Strings to embed.

        Returns:
            A C-contiguous ``float32`` array with one row per input text,
            taken from position 0 of the model's last hidden state. An empty
            input yields an array of shape ``(0, self.dimension)``.
        """
        if not texts:
            # Nothing to tokenize; return an empty batch of the right width.
            return np.empty((0, self.dimension), dtype=np.float32)

        tokens = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=512,
        )
        with torch.no_grad():
            outputs = self.model(**tokens)

        cls_embeddings = outputs.last_hidden_state[:, 0, :]
        # The slice above is a strided view; FAISS needs contiguous float32.
        return np.ascontiguousarray(
            cls_embeddings.cpu().numpy(), dtype=np.float32
        )

    def index_documents(self, documents: List[Dict]):
        """Index documents for retrieval.

        For the ``"faiss"`` backend the ``'text'`` field of each document is
        embedded and appended to the index; repeated calls accumulate. For
        any other backend each document is sent to the ``financial_data``
        Elasticsearch index.

        Args:
            documents: Documents to index. The FAISS path reads
                ``doc['text']`` from each one.
        """
        if not documents:
            return

        if self.retriever_type == "faiss":
            texts = [doc['text'] for doc in documents]
            embeddings = self.encode_text(texts)
            self.index.add(embeddings)
            # Append rather than replace: index.add() accumulates vectors, so
            # the document list must grow in step to keep row ids aligned.
            self.documents.extend(documents)
        else:
            for doc in documents:
                self.es.index(index="financial_data", document=doc)

    def retrieve(self, query: str, k: Optional[int] = None) -> List[Dict]:
        """Retrieve relevant documents.

        Args:
            query: Query text.
            k: Maximum number of hits to request. ``None`` (or any falsy
                value) falls back to ``self.max_documents``.

        Returns:
            A list of ``{'document': ..., 'score': ...}`` dicts whose score is
            at least ``self.similarity_threshold``. FAISS scores are
            ``1 / (1 + distance)``; Elasticsearch scores are the raw
            ``_score`` of each hit.
        """
        k = k or self.max_documents

        if self.retriever_type == "faiss":
            if not self.documents:
                # Nothing has been indexed yet.
                return []

            # Only the FAISS path needs an embedding of the query.
            query_embedding = self.encode_text([query])
            distances, indices = self.index.search(query_embedding, k)

            results = []
            for dist, idx in zip(distances[0], indices[0]):
                # FAISS pads missing neighbours with label -1 when the index
                # holds fewer than k vectors; a negative index would wrap
                # around to the end of the document list.
                if idx < 0:
                    continue
                score = float(1 / (1 + dist))
                if score >= self.similarity_threshold:
                    results.append({
                        'document': self.documents[int(idx)],
                        'score': score,
                    })
            return results

        response = self.es.search(
            index="financial_data",
            query={
                "match": {
                    "text": query
                }
            },
            size=k
        )
        # NOTE(review): the same threshold is applied to the raw
        # Elasticsearch _score, which is not on the 0..1 scale of the FAISS
        # score above - confirm this is intended.
        return [
            {
                'document': hit['_source'],
                'score': hit['_score']
            }
            for hit in response['hits']['hits']
            if hit['_score'] >= self.similarity_threshold
        ]