File size: 3,126 Bytes
754afec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""Retrieval component for the RAG system."""

import faiss
import numpy as np
from typing import List, Dict, Tuple
from elasticsearch import Elasticsearch
from transformers import AutoTokenizer, AutoModel
import torch

class FinancialDataRetriever:
    """Retrieve relevant financial documents.

    Supports two backends selected by config['rag']['retriever']:
    - "faiss": dense retrieval over BERT [CLS] embeddings with an L2 flat index.
    - "elasticsearch": sparse full-text match against the "financial_data" index.
    """

    def __init__(self, config: Dict):
        """Initialize the retriever with configuration.

        Args:
            config: Nested config dict; reads config['rag']['retriever'],
                config['rag']['max_documents'], and
                config['rag']['similarity_threshold'].
        """
        self.retriever_type = config['rag']['retriever']
        self.max_documents = config['rag']['max_documents']
        self.similarity_threshold = config['rag']['similarity_threshold']

        # FAISS flat (exact) L2 index over the embeddings.
        self.dimension = 768  # hidden size of bert-base-uncased
        self.index = faiss.IndexFlatL2(self.dimension)
        # Kept in insertion order so FAISS row ids map back to documents,
        # even across multiple index_documents() calls.
        self.documents: List[Dict] = []

        # Transformer encoder used for both queries and documents.
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        self.model = AutoModel.from_pretrained('bert-base-uncased')
        self.model.eval()  # disable dropout so embeddings are deterministic

        # Elasticsearch client only when the sparse retriever is selected.
        if self.retriever_type == "elasticsearch":
            self.es = Elasticsearch()

    def encode_text(self, texts: List[str]) -> np.ndarray:
        """Encode texts with BERT.

        Args:
            texts: Batch of input strings (truncated to 512 tokens each).

        Returns:
            float32 array of shape (len(texts), 768) — the [CLS] embeddings.
        """
        tokens = self.tokenizer(texts, padding=True, truncation=True,
                                return_tensors="pt", max_length=512)
        with torch.no_grad():
            outputs = self.model(**tokens)
            # Position 0 is the [CLS] token, used as the sentence embedding.
            embeddings = outputs.last_hidden_state[:, 0, :].numpy()
        # FAISS requires contiguous float32 input.
        return np.ascontiguousarray(embeddings, dtype=np.float32)

    def index_documents(self, documents: List[Dict]):
        """Index documents for retrieval.

        Args:
            documents: Dicts, each with at least a 'text' key.
        """
        if not documents:
            return
        if self.retriever_type == "faiss":
            texts = [doc['text'] for doc in documents]
            embeddings = self.encode_text(texts)
            self.index.add(embeddings)
            # Extend (not overwrite) so repeated calls keep the document
            # list aligned with the growing FAISS index.
            self.documents.extend(documents)
        else:
            for doc in documents:
                self.es.index(index="financial_data", document=doc)

    def retrieve(self, query: str, k: int = None) -> List[Dict]:
        """Retrieve up to ``k`` documents relevant to ``query``.

        Args:
            query: Free-text query string.
            k: Max results; defaults to config's max_documents. k <= 0
               returns an empty list.

        Returns:
            List of {'document': ..., 'score': float} filtered by
            similarity_threshold. FAISS scores are 1/(1+L2distance) in (0, 1];
            Elasticsearch scores are raw BM25 _score values.
        """
        if k is None:
            k = self.max_documents
        if k <= 0:
            return []
        query_embedding = self.encode_text([query])

        if self.retriever_type == "faiss":
            if self.index.ntotal == 0:
                # Nothing indexed yet — avoid searching an empty index.
                return []
            distances, indices = self.index.search(query_embedding, k)
            results = [
                {
                    'document': self.documents[idx],
                    'score': float(1 / (1 + dist))
                }
                for dist, idx in zip(distances[0], indices[0])
                # FAISS pads with index -1 when fewer than k vectors exist;
                # without this guard, documents[-1] would wrap to the last doc.
                if idx != -1 and 1 / (1 + dist) >= self.similarity_threshold
            ]
        else:
            response = self.es.search(
                index="financial_data",
                query={
                    "match": {
                        "text": query
                    }
                },
                size=k
            )
            results = [
                {
                    'document': hit['_source'],
                    'score': hit['_score']
                }
                for hit in response['hits']['hits']
                if hit['_score'] >= self.similarity_threshold
            ]

        return results