"""Retrieval component for the RAG system."""

from typing import Dict, List, Optional

import faiss
import numpy as np
import torch
from elasticsearch import Elasticsearch
from transformers import AutoTokenizer, AutoModel
|
|
class FinancialDataRetriever:
    """Retrieves relevant financial documents via FAISS or Elasticsearch."""

    def __init__(self, config: Dict):
        """Initialize the retriever with configuration."""
        self.retriever_type = config['rag']['retriever']
        self.max_documents = config['rag']['max_documents']
        self.similarity_threshold = config['rag']['similarity_threshold']

        # bert-base-uncased produces 768-dimensional embeddings.
        self.dimension = 768
        # Exact (brute-force) L2 index; sufficient for moderate corpus sizes.
        self.index = faiss.IndexFlatL2(self.dimension)
        self.documents: List[Dict] = []  # populated by index_documents() for FAISS id lookups

        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        self.model = AutoModel.from_pretrained('bert-base-uncased')
        self.model.eval()  # disable dropout so embeddings are deterministic

        if self.retriever_type == "elasticsearch":
            # Assumes a local Elasticsearch node; adjust the URL for your deployment.
            self.es = Elasticsearch("http://localhost:9200")
|
    def encode_text(self, texts: List[str]) -> np.ndarray:
        """Encode texts into dense vectors using BERT [CLS] embeddings."""
        tokens = self.tokenizer(texts, padding=True, truncation=True,
                                return_tensors="pt", max_length=512)
        with torch.no_grad():
            outputs = self.model(**tokens)
        # Use the [CLS] token embedding as the sentence representation.
        embeddings = outputs.last_hidden_state[:, 0, :].numpy()
        # FAISS expects C-contiguous float32 input.
        return np.ascontiguousarray(embeddings, dtype=np.float32)
|
    def index_documents(self, documents: List[Dict]):
        """Index documents for retrieval."""
        if self.retriever_type == "faiss":
            texts = [doc['text'] for doc in documents]
            embeddings = self.encode_text(texts)
            self.index.add(embeddings)
            # Keep the raw documents so FAISS row ids can be mapped back to them.
            self.documents = documents
        else:
            for doc in documents:
                self.es.index(index="financial_data", document=doc)
|
    def retrieve(self, query: str, k: Optional[int] = None) -> List[Dict]:
        """Retrieve the top-k documents relevant to the query."""
        k = k or self.max_documents

        if self.retriever_type == "faiss":
            query_embedding = self.encode_text([query])
            distances, indices = self.index.search(query_embedding, k)
            results = []
            for dist, idx in zip(distances[0], indices[0]):
                if idx == -1:
                    # FAISS pads with -1 when fewer than k vectors are indexed.
                    continue
                score = float(1 / (1 + dist))  # map L2 distance to a (0, 1] similarity
                if score >= self.similarity_threshold:
                    results.append({
                        'document': self.documents[idx],
                        'score': score,
                    })
        else:
            response = self.es.search(
                index="financial_data",
                query={"match": {"text": query}},
                size=k,
            )
            # Note: Elasticsearch BM25 scores are unbounded, so similarity_threshold
            # may need a different value than in the FAISS path.
            results = [
                {
                    'document': hit['_source'],
                    'score': hit['_score'],
                }
                for hit in response['hits']['hits']
                if hit['_score'] >= self.similarity_threshold
            ]

        return results
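

# A minimal usage sketch (illustrative only): the config shape below mirrors the
# keys read in __init__, but the concrete values and the sample documents are
# assumptions, not part of the original module.
if __name__ == "__main__":
    example_config = {
        'rag': {
            'retriever': 'faiss',
            'max_documents': 5,
            'similarity_threshold': 0.0,
        }
    }
    retriever = FinancialDataRetriever(example_config)
    retriever.index_documents([
        {'text': 'Q3 revenue grew 12% year over year.', 'source': 'earnings_call'},
        {'text': 'Operating margin declined due to higher input costs.', 'source': '10-Q'},
    ])
    for result in retriever.retrieve("How did revenue change in Q3?"):
        print(f"{result['score']:.3f}  {result['document']['text']}")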