from src.embeddings_model import GEmbeddings
from src.text_generation_model import GLLM
from src.pinecone_index import PineconeIndex

from typing import List
import asyncio

from llama_index.core.evaluation import SemanticSimilarityEvaluator, EvaluationResult
from llama_index.core.base.embeddings.base import SimilarityMode
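
# Prompt template for roleplaying Gerard Lee. The placeholders are filled in
# generate_text(): chat history, retrieved index context, and the user query.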

prompt_template = """
<system instruction>
You are Gerard Lee, a data enthusiast with 6 years of experience in the field who stays humble about his success. Imagine you are in a conversation with someone who is interested in your portfolio.
Reply as faithfully as possible and in no more than 5 complete sentences, unless the <user query> asks you to elaborate in detail. Use contents from <context> only, without prior knowledge, except for referring to <chat history> for a seamless conversation.
</system instruction>

<chat history>
{context_history}
</chat history>

<context>
{context_from_index}
</context>

<user query>
{user_query}
</user query>
"""

class GLlamaIndex:
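    """Retrieval-augmented chat helper.

    Retrieves context from a Pinecone-backed llama-index vectorstore,
    evaluates candidate contexts for semantic similarity, and generates
    a persona-styled reply with the configured LLM.
    """
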
    def __init__(
        self,
        logger,
        emb_model: GEmbeddings,
        text_model: GLLM,
        index: PineconeIndex,
        similarity_threshold: float
    ) -> None:
        self.logger = logger
        self.emb_model = emb_model
        self.llm = text_model
        self.index = index
        self.evaluator = self._set_evaluator(similarity_threshold)
        self.prompt_template = prompt_template
    
    def _set_evaluator(self, similarity_threshold: float) -> SemanticSimilarityEvaluator:
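        # Embedding-based similarity between query and retrieved context;
        # scores below `similarity_threshold` mark the evaluation as failing.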
        sem_evaluator = SemanticSimilarityEvaluator(
            similarity_mode=SimilarityMode.DEFAULT,
            similarity_threshold=similarity_threshold,
        )
        return sem_evaluator

    def format_history(self, history: List[str]) -> str:
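        # Drop empty turns and join the rest, one turn per line.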
        return "\n".join(list(filter(None, history)))

    async def aget_context_with_history(
        self,
        query: str,
        history: List[str]
    ) -> str:
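        """Retrieve context for `query`, optionally informed by chat history.

        With history, two retrievals run concurrently (the query alone and the
        query prefixed with the last two turns); the candidate context that
        scores higher against its query is returned.
        """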
        if not history:
            result = self.index.retrieve_context(query)
            return result["result"]
            
        extended_query = f"{self.format_history(history[-2:])}\n{query}"

        results = await self.index.aretrieve_context_multi(
            [query, extended_query]
        )
        self.logger.info(f"retrieval results: {results}")
        eval_results = await self.aevaluate_context_multi(
            [query, extended_query],
            [r["result"] for r in results]
        )
        self.logger.info(f"eval results: {eval_results}")
        if eval_results[0].score > eval_results[1].score:
            return results[0]["result"]
        return results[1]["result"]
    
    async def aevaluate_context(
        self,
        query: str,
        returned_context: str
    ) -> EvaluationResult:
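        # Score how semantically similar the retrieved context is to the query.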
        result = await self.evaluator.aevaluate(
            response=returned_context,
            reference=query,
        )
        return result

    async def aevaluate_context_multi(
        self,
        query_list: List[str],
        returned_context_list: List[str]
    ) -> List[EvaluationResult]:
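        # Evaluate all (query, context) pairs concurrently.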
        result = await asyncio.gather(
            *(
                self.aevaluate_context(query, returned_context)
                for query, returned_context in zip(query_list, returned_context_list)
            )
        )
        return result

    def generate_text(
        self,
        query: str,
        history: List[str],
    ) -> str:
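        """Answer `query` using retrieved context, chat history, and the LLM."""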
        # get chat history
        context_history = self.format_history(history=history)

        # get retrieval context(s) from llama-index vectorstore index
        try:
            # without history: a single context retrieval, no evaluation needed
            if not history:
                result_query_only = self.index.retrieve_context(query)
                context_from_index_selected = result_query_only["result"]

            # with history: retrieve multiple contexts concurrently, then
            # evaluate to decide which context to use
            else:
                context_from_index_selected = asyncio.run(
                    self.aget_context_with_history(query=query, history=history)
                )

        except Exception as e:
            self.logger.error(f"Exception {e} occurred when retrieving context\n")
            return "Something went wrong. Please try again later."
        
        self.logger.info(f"Context from Llama-Index:\n{context_from_index_selected}\n")

        # generate text with the prompt template to answer in persona
        prompt_with_context = self.prompt_template.format(
            context_history=context_history,
            context_from_index=context_from_index_selected,
            user_query=query,
        )
        try:
            result = self.llm.gai_generate_content(
                prompt=prompt_with_context,
                temperature=0.5,
            )
            if result is None:
                result = "Seems something went wrong. Please try again later."
                self.logger.error("Received a 'None' result from the LLM\n")

        except Exception as e:
            result = "Seems something went wrong. Please try again later."
            self.logger.error(f"Exception {e} occurred\n")

        return result
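
# Example wiring (a minimal sketch for illustration only: the no-argument
# constructors for GEmbeddings, GLLM, and PineconeIndex below are assumptions,
# not their verified signatures):
#
#   import logging
#
#   rag = GLlamaIndex(
#       logger=logging.getLogger(__name__),
#       emb_model=GEmbeddings(),
#       text_model=GLLM(),
#       index=PineconeIndex(),
#       similarity_threshold=0.7,
#   )
#   reply = rag.generate_text(query="What projects have you worked on?", history=[])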