Spaces:
Sleeping
Sleeping
File size: 3,126 Bytes
754afec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
"""Retrieval component for the RAG system."""
import faiss
import numpy as np
from typing import List, Dict, Tuple
from elasticsearch import Elasticsearch
from transformers import AutoTokenizer, AutoModel
import torch
class FinancialDataRetriever:
    """Retrieve relevant financial documents via FAISS or Elasticsearch.

    The backend is chosen by ``config['rag']['retriever']``:
    "faiss" (dense BERT embeddings + exact L2 search) or
    "elasticsearch" (BM25 full-text match).
    """

    def __init__(self, config: Dict):
        """Initialize the retriever with configuration.

        Args:
            config: Nested dict; reads config['rag'] keys 'retriever'
                ("faiss" or "elasticsearch"), 'max_documents', and
                'similarity_threshold'.
        """
        rag_cfg = config['rag']
        self.retriever_type = rag_cfg['retriever']
        self.max_documents = rag_cfg['max_documents']
        self.similarity_threshold = rag_cfg['similarity_threshold']

        # Exact L2 index; 768 is bert-base's hidden size.
        self.dimension = 768
        self.index = faiss.IndexFlatL2(self.dimension)

        # Parallel list mapping FAISS row ids back to document dicts.
        # Initialized here so retrieve() before index_documents() returns
        # no results instead of raising AttributeError.
        self.documents: List[Dict] = []

        # BERT encoder for embeddings; eval() disables dropout so the
        # embeddings are deterministic across calls.
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        self.model = AutoModel.from_pretrained('bert-base-uncased')
        self.model.eval()

        if self.retriever_type == "elasticsearch":
            # NOTE(review): connects to the default localhost:9200 —
            # confirm whether host/auth should come from config instead.
            self.es = Elasticsearch()

    def encode_text(self, texts: List[str]) -> np.ndarray:
        """Encode a batch of texts into [CLS]-token BERT embeddings.

        Args:
            texts: List of strings (truncated to 512 tokens each).

        Returns:
            Contiguous float32 array of shape (len(texts), 768), as
            required by FAISS.
        """
        tokens = self.tokenizer(texts, padding=True, truncation=True,
                                return_tensors="pt", max_length=512)
        with torch.no_grad():
            outputs = self.model(**tokens)
        # [CLS] (position 0) embedding as the sentence representation.
        cls = outputs.last_hidden_state[:, 0, :]
        return np.ascontiguousarray(cls.numpy(), dtype=np.float32)

    def index_documents(self, documents: List[Dict]):
        """Index documents for retrieval.

        Args:
            documents: Dicts each containing at least a 'text' key.
        """
        if not documents:
            return  # avoid encoding an empty batch
        if self.retriever_type == "faiss":
            texts = [doc['text'] for doc in documents]
            embeddings = self.encode_text(texts)
            self.index.add(embeddings)
            # Extend (not overwrite) so repeated calls keep self.documents
            # aligned with the rows already accumulated in the FAISS index.
            self.documents.extend(documents)
        else:
            for doc in documents:
                self.es.index(index="financial_data", document=doc)

    def retrieve(self, query: str, k: int = None) -> List[Dict]:
        """Retrieve up to k documents relevant to the query.

        Args:
            query: Free-text search string.
            k: Max results; defaults to config max_documents.

        Returns:
            List of {'document': <dict>, 'score': <float>} entries whose
            score passes similarity_threshold. FAISS scores are
            1 / (1 + L2_distance), mapped into (0, 1].
        """
        k = k or self.max_documents
        if self.retriever_type == "faiss":
            # Only the dense path needs the (expensive) query embedding.
            query_embedding = self.encode_text([query])
            distances, indices = self.index.search(query_embedding, k)
            results = []
            for dist, idx in zip(distances[0], indices[0]):
                # FAISS pads with idx == -1 when fewer than k vectors are
                # indexed; without this guard, documents[-1] would wrap to
                # the last document and return a spurious hit.
                if idx < 0:
                    continue
                score = float(1.0 / (1.0 + dist))
                if score >= self.similarity_threshold:
                    results.append({'document': self.documents[idx],
                                    'score': score})
        else:
            response = self.es.search(
                index="financial_data",
                query={"match": {"text": query}},
                size=k,
            )
            results = [
                {'document': hit['_source'], 'score': hit['_score']}
                for hit in response['hits']['hits']
                if hit['_score'] >= self.similarity_threshold
            ]
        return results
|