File size: 5,664 Bytes
a557e27
e1a76a5
a557e27
f5b422d
4d35513
 
98ea658
a557e27
 
 
 
 
 
e1a76a5
 
f187dab
32bf62b
17aca0b
7f4ceb4
 
f187dab
 
17aca0b
f187dab
 
 
 
 
 
 
 
 
 
9ad6ced
956f515
9855b70
ef03b1a
956f515
ef03b1a
9855b70
ef03b1a
9855b70
ef03b1a
f187dab
a557e27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1a76a5
 
a557e27
 
 
 
 
 
 
aec5f52
 
a557e27
0ece682
a557e27
 
 
 
cc17af3
a557e27
 
 
 
 
 
f55cae9
a557e27
 
 
 
 
 
 
 
 
 
 
 
 
faa90d7
a557e27
d456856
7f4ceb4
 
6f4e437
00dfa0c
4f8e4c3
f187dab
a557e27
 
ae7bc30
 
 
a557e27
 
 
907d604
a557e27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f187dab
6781233
 
a557e27
99992c7
89db9ff
e0eb9d0
 
a557e27
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from langchain.prompts.prompt import PromptTemplate
from langchain.llms import OpenAIChat
from langchain.chains import ChatVectorDBChain
from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import FAISS
import os
from typing import Optional, Tuple
import gradio as gr
import pickle
from threading import Lock

# System message handed to the chat model as ``prefix_messages`` in get_chain().
prefix_messages = [
    {
        "role": "system",
        "content": "You are a helpful assistant that is very good at answering questions about investments using the information given.",
    },
]

# Maps the label shown in the embedding selector to the model id that
# load_vectorstore() receives.
model_options = {
    'all-mpnet-base-v2': "sentence-transformers/all-mpnet-base-v2",
    'instructor-base': "hkunlp/instructor-base",
}

# Selector labels, in the dict's insertion order.
model_options_list = list(model_options)

def load_vectorstore(model):
    """Load the local FAISS index that pairs with an embedding model.

    The model family is chosen by substring match on ``model``: names
    containing ``'mpnet'`` use ``HuggingFaceEmbeddings`` with the
    ``'vanguard-embeddings'`` index, and names containing ``'instructor'``
    use ``HuggingFaceInstructEmbeddings`` with the
    ``'vanguard_embeddings_inst'`` index.

    Args:
        model: Embedding model name passed through as ``model_name``.

    Returns:
        The vector store returned by ``FAISS.load_local``.

    Raises:
        ValueError: If ``model`` matches neither supported family. The
            function previously fell through and returned ``None``, which
            only failed later when a chain was built on top of it.
    """
    if 'mpnet' in model:
        emb = HuggingFaceEmbeddings(model_name=model)
        return FAISS.load_local('vanguard-embeddings', emb)

    if 'instructor' in model:
        emb = HuggingFaceInstructEmbeddings(
            model_name=model,
            query_instruction='Represent the Financial question for retrieving supporting paragraphs: ',
            embed_instruction='Represent the Financial paragraph for retrieval: ',
        )
        return FAISS.load_local('vanguard_embeddings_inst', emb)

    raise ValueError(f"Unsupported embedding model: {model!r}")

# Default embeddings: load the 'all-mpnet-base-v2' index at import time so the
# module-level ``vectorstore`` is populated before any chat request is handled.
# on_value_change() rebinds this global when a different model is selected.
vectorstore = load_vectorstore(model_options['all-mpnet-base-v2'])

def on_value_change(change):
    """Swap the active vector store when the embedding selection changes.

    ``change`` is the selected label, i.e. a key of ``model_options``. The
    matching model id is loaded and the module-level ``vectorstore`` is
    rebound to the result.
    """
    global vectorstore
    selected_model = model_options[change]
    vectorstore = load_vectorstore(selected_model)
    
# vectorstore = load_vectorstore('vanguard-embeddings',sbert_emb)
    
# Prompt that rewrites a follow-up question into a standalone question.
# Placeholders: {chat_history} and {question}. Passed to the chain as
# ``condense_question_prompt`` in get_chain().
# NOTE(review): "You can assume the question about investing" looks like it is
# missing "is"; the prompt text is left untouched here — confirm before editing.
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
You can assume the question about investing and the investment management industry.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

# Prompt that produces the final answer. Placeholders: {question} and
# {context} (declared explicitly via ``input_variables``). Passed to the chain
# as ``qa_prompt`` in get_chain().
template = """You are an AI assistant for answering questions about investing and the investment management industry.
You are given the following extracted parts of a long document and a question. Provide a conversational answer.
If you don't know the answer, just say "Hmm, I'm not sure." Don't try to make up an answer.
If the question is not about investing, politely inform them that you are tuned to only answer questions about investing and the investment management industry.
Question: {question}
=========
{context}
=========
Answer in Markdown:"""
QA_PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])


def get_chain(vectorstore):
    """Build a ``ChatVectorDBChain`` over ``vectorstore``.

    The chat model is created with streaming enabled, a stdout streaming
    callback handler, ``temperature=0`` and the module-level
    ``prefix_messages``. The chain uses ``QA_PROMPT`` for answering and
    ``CONDENSE_QUESTION_PROMPT`` for rewriting follow-up questions.
    """
    stdout_callbacks = CallbackManager([StreamingStdOutCallbackHandler()])
    chat_model = OpenAIChat(
        streaming=True,
        callback_manager=stdout_callbacks,
        verbose=True,
        temperature=0,
        prefix_messages=prefix_messages,
    )
    return ChatVectorDBChain.from_llm(
        chat_model,
        vectorstore,
        qa_prompt=QA_PROMPT,
        condense_question_prompt=CONDENSE_QUESTION_PROMPT,
    )
    
    
class ChatWrapper:
    """Callable chat handler that runs one question through the QA chain.

    Calls are serialized with a lock, so only one request at a time builds
    and runs a chain.
    """

    def __init__(self):
        # Held for the whole duration of each __call__.
        self.lock = Lock()

    def __call__(
        self, inp: str, history: Optional[Tuple[str, str]], chain
    ):
        """Execute the chat functionality.

        Args:
            inp: The user's question.
            history: Previous ``(question, answer)`` pairs, or a falsy value
                on the first turn. Despite the annotation it is used as a
                list at runtime (it is defaulted to ``[]`` and appended to).
            chain: Ignored. A fresh chain is built on every call from the
                module-level ``vectorstore``, so the current embedding
                selection is always used. The parameter is kept so the
                existing call signature stays unchanged.

        Returns:
            ``(history, history)`` with the new ``(inp, answer)`` pair
            appended.

        Raises:
            Any exception raised while building or running the chain
            propagates unchanged; the lock is released either way.
        """
        # The context manager releases the lock even if the chain raises.
        with self.lock:
            history = history or []
            chain = get_chain(vectorstore)
            # Run chain and append input.
            output = chain({"question": inp, "chat_history": history})["answer"]
            history.append((inp, output))
        return history, history

# Gradio UI: embedding selector, chat window, question box and examples.
block = gr.Blocks(css=".gradio-container {background-color: lightgray}")

with block:
    with gr.Row():
        gr.Markdown("<h3><center>Chat-Your-Data (Investor Education)</center></h3>")

    with gr.Row():
        embeddings = gr.Radio(choices=model_options_list,value=model_options_list[0], label='Choose your Embedding Model',
                             interactive=True)
        # Reload the vector store whenever a different embedding model is picked.
        embeddings.change(on_value_change, embeddings)

    # The radio value is a label (a key of model_options), not a model id, so
    # translate it before loading, the same way on_value_change() does.
    vectorstore = load_vectorstore(model_options[embeddings.value])

    chatbot = gr.Chatbot()

    chat = ChatWrapper()


    with gr.Row():
        message = gr.Textbox(
            label="What's your question?",
            placeholder="Ask questions about Investing",
            lines=1,
        )
        submit = gr.Button(value="Send", variant="secondary").style(full_width=False)

    gr.Examples(
        examples=[
            "What are the benefits of investing in ETFs?",
            "What is the average cost of investing in a managed fund?",
            "At what age can I start investing?",
            "Do you offer investment accounts for kids?"
        ],
        inputs=message,
    )

    gr.HTML("Demo application of a LangChain chain.")

    gr.HTML(
        "<center>Powered by <a href='https://github.com/hwchase17/langchain'>LangChain 🦜️🔗</a></center>"
    )

    # ``state`` carries the chat history between calls. ``agent_state`` fills
    # the handler's third (``chain``) argument, which ChatWrapper overwrites.
    state = gr.State()
    agent_state = gr.State()


    # Both the Send button and pressing Enter in the textbox run the handler.
    submit.click(chat, inputs=[message, state, agent_state], outputs=[chatbot, state])
    message.submit(chat, inputs=[message, state, agent_state], outputs=[chatbot, state])

    gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=nickmuchi-investor-chatchain)")

block.launch(debug=True)