from typing import Optional, List, Tuple
from langchain.docstore.document import Document as LangchainDocument
from rank_bm25 import BM25Okapi
from langchain_community.vectorstores import FAISS
from ragatouille import RAGPretrainedModel
from litellm import completion
import os
import retriver
import config


class RAGAnswerGenerator:
    """Hybrid BM25 + semantic retrieval with optional reranking and a Groq-hosted LLM for answer generation."""

    def __init__(self, docs: List[LangchainDocument], bm25: BM25Okapi, knowledge_index: FAISS, reranker: Optional[RAGPretrainedModel] = None):
        self.bm25 = bm25
        self.knowledge_index = knowledge_index
        self.docs = docs
        self.reranker = reranker
        # Fails fast with a KeyError if GROQ_API_KEY is not set.
        self.llm_key = os.environ["GROQ_API_KEY"]

    def retrieve_documents(
        self,
        question: str,
        num_retrieved_docs: int,
        bm_25_flag: bool,
        semantic_flag: bool
    ) -> List[str]:
        print("=> Retrieving documents...")
        relevant_docs = []

        if bm_25_flag or semantic_flag:
            result = retriver.search(
                self.docs,
                self.bm25,
                self.knowledge_index,
                question,
                use_bm25=bm_25_flag,
                use_semantic_search=semantic_flag,
                top_k=num_retrieved_docs
            )
            if bm_25_flag and not semantic_flag:
                # BM25-only retrieval already yields plain strings.
                relevant_docs = result
            else:
                # Semantic and hybrid retrieval yield LangChain documents.
                relevant_docs = [doc.page_content for doc in result]

        # Returns an empty list when both retrieval flags are disabled.
        return relevant_docs

    def rerank_documents(self, question: str, documents: List[str], num_docs_final: int) -> List[str]:
        if self.reranker and documents:
            print("=> Reranking documents...")
            reranked_docs = self.reranker.rerank(question, documents, k=num_docs_final)
            return [doc["content"] for doc in reranked_docs]
        # No reranker configured (or nothing retrieved): keep the first num_docs_final documents.
        return documents[:num_docs_final]

    def format_context(self, documents: List[str]) -> str:
        if not documents:
            return "No retrieved documents available."
        return "\n".join([f"[{i + 1}] {doc}" for i, doc in enumerate(documents)])

    def generate_answer(
        self,
        question: str,
        context: str,
        temperature: float,
    ) -> str:
        print("=> Generating answer...")
        if context.strip() == "No retrieved documents available.":
            response = completion(
                model="groq/llama3-8b-8192",
                messages=[
                    {"role": "system", "content": config.LLM_ONLY_PROMPT},
                    {"role": "user", "content": f"Question: {question}"}
                ],
                api_key=self.llm_key,
                temperature=temperature
            )
        else:
            response = completion(
                model="groq/llama3-8b-8192",
                messages=[
                    {"role": "system", "content": config.RAG_PROMPT},
                    {"role": "user", "content": f""" Context: {context} Question: {question} """}
                ],
                api_key=self.llm_key,
                temperature=temperature
            )
        return response.get("choices", [{}])[0].get("message", {}).get("content", "No response content found")

    def answer(self, question: str, temperature: float, num_retrieved_docs: int = 30, num_docs_final: int = 5, bm_25_flag: bool = True, semantic_flag: bool = True) -> Tuple[str, List[str]]:
        relevant_docs = self.retrieve_documents(question, num_retrieved_docs, bm_25_flag, semantic_flag)
        print(f"=> Retrieved {len(relevant_docs)} documents")
        relevant_docs = self.rerank_documents(question, relevant_docs, num_docs_final)
        print(f"=> Kept {len(relevant_docs)} documents after reranking")
        context = self.format_context(relevant_docs)
        answer = self.generate_answer(question, context, temperature)
        document_list = [f"[{i + 1}] {doc}" for i, doc in enumerate(relevant_docs)] if relevant_docs else []
        return answer, document_list
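

# Usage sketch (not part of the original module): wires RAGAnswerGenerator to a
# toy corpus. The embedding model below is an illustrative choice, GROQ_API_KEY
# must be set, and retriver.search / config prompts are assumed to exist as used above.
if __name__ == "__main__":
    from langchain_community.embeddings import HuggingFaceEmbeddings

    corpus = [
        "FAISS is a library for efficient similarity search over dense vectors.",
        "BM25 is a bag-of-words ranking function based on term frequencies.",
    ]
    docs = [LangchainDocument(page_content=text) for text in corpus]
    bm25 = BM25Okapi([text.split() for text in corpus])
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    knowledge_index = FAISS.from_documents(docs, embeddings)

    rag = RAGAnswerGenerator(docs, bm25, knowledge_index)
    answer, sources = rag.answer("What is BM25?", temperature=0.2, num_retrieved_docs=2, num_docs_final=2)
    print(answer)
    for source in sources:
        print(source)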