from langchain.prompts.prompt import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.callbacks import StdOutCallbackHandler
from langchain.vectorstores import FAISS
from langchain.memory import ConversationBufferMemory
from typing import List, Optional, Tuple
import gradio as gr
from threading import Lock

from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

site_options = {'US': 'vanguard_embeddings_US',
                'AUS': 'vanguard_embeddings'}

site_options_list = list(site_options.keys())

memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True, output_key='answer')
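# Because the chain below is built with return_source_documents=True, it emits
# both 'answer' and 'source_documents'; output_key='answer' tells the memory
# which of those outputs to store in the chat history.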

def load_prompt():

    system_template="""Use only the following pieces of context that has been scraped from a website to answer the users question accurately.
    Do not use any information not provided in the website context.
    If you don't know the answer, just say 'There is no relevant answer in the Investor Website', 
    don't try to make up an answer.
    
    ALWAYS return a "SOURCES" part in your answer.
    The "SOURCES" part should be a reference to the source of the document from which you got your answer.
    
    Remember, do not reference any information not given in the context.
    If the answer is not available in the given context just say 'There is no relevant answer in the website content'
    
    Follow the below format when answering:
    
    Question: {question}
    SOURCES: [xyz]
      
    Begin!
    ----------------
    {context}"""
    
    messages = [
        SystemMessagePromptTemplate.from_template(system_template),
        HumanMessagePromptTemplate.from_template("{question}")
    ]
    prompt = ChatPromptTemplate.from_messages(messages)
        
    return prompt
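
# A minimal sketch of how the template above gets filled at query time
# (the values here are illustrative, not taken from the app):
#   load_prompt().format_prompt(
#       question="What is an ETF?",
#       context="<retrieved website text>",
#   ).to_messages()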
    
def load_vectorstore(site):
    '''Load the embeddings model and the FAISS vectorstore for the selected site.'''
    emb = HuggingFaceEmbeddings(model_name="all-mpnet-base-v2")
    
    return FAISS.load_local(site_options[site], emb)
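
# The FAISS indexes named in site_options are assumed to have been built
# offline with something like this sketch (`docs` is hypothetical: the scraped
# website pages as LangChain Documents):
#   emb = HuggingFaceEmbeddings(model_name="all-mpnet-base-v2")
#   FAISS.from_documents(docs, emb).save_local("vanguard_embeddings_US")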

#default embeddings and store
vectorstore = load_vectorstore(site_options_list[0])

def on_value_change(site):
    '''When radio changes, change the website reference data'''
    
    global vectorstore
    vectorstore = load_vectorstore(site)
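
# Note: a module-level global keeps this demo simple, but it is shared across
# all browser sessions; per-user site selection would need gr.State instead.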
    
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
You can assume the question is about investing and the investment management industry.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
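
# This custom condense prompt replaces the stock CONDENSE_QUESTION_PROMPT that
# langchain ships in langchain.chains.conversational_retrieval.prompts.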

template = """You are an AI assistant for answering questions about investing and the investment management industry.
You are given the following extracted parts of a long document and a question. Provide a conversational answer.
If you don't know the answer, just say "Hmm, I'm not sure." Don't try to make up an answer.
If the question is not about investing, politely inform them that you are tuned to only answer questions about investing and the investment management industry.
Question: {question}
=========
{context}
=========
Answer in Markdown:"""
QA_PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])
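
# NOTE: QA_PROMPT is kept here for reference but is not wired into the chain;
# get_chain() below builds its combine-docs prompt from load_prompt() instead.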


def get_chain(vectorstore):
    # ChatOpenAI reads the API key from the OPENAI_API_KEY environment variable.
    llm = ChatOpenAI(streaming=True,
                     callbacks=[StdOutCallbackHandler()],
                     verbose=True,
                     temperature=0,
                     model_name='gpt-4')

    question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
    doc_chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=load_prompt())

    chain = ConversationalRetrievalChain(retriever=vectorstore.as_retriever(search_kwargs={"k": 4}),
                                         question_generator=question_generator,
                                         combine_docs_chain=doc_chain,
                                         memory=memory,
                                         return_source_documents=True,
                                         get_chat_history=lambda h: h)

    return chain
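
# A usage sketch (the question text is illustrative): the chain returns a dict,
# and because return_source_documents=True both keys below are populated:
#   qa = get_chain(vectorstore)
#   result = qa({"question": "What are the benefits of investing in ETFs?"})
#   result["answer"], result["source_documents"]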

def load_chain():
    chain = get_chain(vectorstore)
    return chain
    
    
class ChatWrapper:
    """Thread-safe wrapper around the QA chain for use as a Gradio callback."""

    def __init__(self):
        self.lock = Lock()

    def __call__(
        self, inp: str, history: Optional[List[Tuple[str, str]]], chain
    ):
        """Execute the chat functionality."""
        with self.lock:
            history = history or []
            # The chain is None until the user clicks the load button.
            if chain is None:
                history.append((inp, "Please click 'Step 1' to load the QA system first."))
                return history, history
            # Run the chain on the question and append the exchange to the history.
            output = chain({"question": inp})["answer"]
            history.append((inp, output))
        return history, history

block = gr.Blocks(css=".gradio-container {background-color: lightgray}")

with block:
    with gr.Row():
        gr.Markdown("<h3><center>Chat-Your-Data (Investor Education)</center></h3>")
        embed_but = gr.Button(value='Step 1: Click Me to Load the QA System')
    with gr.Row():
        websites = gr.Radio(choices=site_options_list,value=site_options_list[0],label='Select US or AUS website data',
                interactive=True)
        websites.change(on_value_change, websites)
    
    # The module-level vectorstore (loaded above) is used until the radio selection changes.
    
    chatbot = gr.Chatbot()

    chat = ChatWrapper()


    with gr.Row():
        message = gr.Textbox(
            label="What's your question?",
            placeholder="Ask questions about Investing",
            lines=1,
        )
        submit = gr.Button(value="Send", variant="secondary").style(full_width=False)

    gr.Examples(
        examples=[
            "What are the benefits of investing in ETFs?",
            "What is the average cost of investing in a managed fund?",
            "At what age can I start investing?",
            "Do you offer investment accounts for kids?"
        ],
        inputs=message,
    )

    gr.HTML("Demo application of a LangChain chain.")

    gr.HTML(
        "<center>Powered by <a href='https://github.com/hwchase17/langchain'>LangChain πŸ¦œοΈπŸ”—</a></center>"
    )

    state = gr.State()
    agent_state = gr.State()

    
    submit.click(chat, inputs=[message, state, agent_state], outputs=[chatbot, state])
    message.submit(chat, inputs=[message, state, agent_state], outputs=[chatbot, state])

    embed_but.click(
        load_chain,
        outputs=[agent_state],
    )

    gr.Markdown("![](https://komarev.com/ghpvc/?username=nickmuchi87&style=flat-square)")

block.launch(debug=True)
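
# A local-dev launch; passing share=True as well would additionally create a
# temporary public Gradio link: block.launch(debug=True, share=True)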