Manasa1 commited on
Commit
9dafbdb
·
verified ·
1 Parent(s): fc719be

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain import PromptTemplate
2
+ from langchain.embeddings import HuggingFaceEmbeddings
3
+ from langchain.vectorstores import FAISS
4
+ from langchain.llms import CTransformers
5
+ from langchain.chains import RetrievalQA
6
+ import chainlit as cl
7
+
8
# Path to the pre-built FAISS index on disk (produced by a separate
# ingestion step — not visible in this file; TODO confirm its builder).
DB_FAISS_PATH = "vectorstores/db_faiss"

# Prompt used for every retrieval-QA call. The chain substitutes
# {context} (retrieved document chunks) and {question} (the user query).
custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context: {context}
Question: {question}

only return the helpful answer below and nothing else.
Helpful answer:
"""
19
+
20
def set_custom_prompt():
    """Build the PromptTemplate wired into the QA retrieval chain.

    Returns:
        A PromptTemplate over ``custom_prompt_template`` expecting the
        ``context`` and ``question`` variables.
    """
    return PromptTemplate(
        template=custom_prompt_template,
        input_variables=['context', 'question'],
    )
27
+
28
+
29
def load_llm(model_path="C:/Users/sanath/Downloads/llama-2-7b-chat.ggmlv3.q8_0.bin"):
    """
    Load the Llama-2 chat model through CTransformers for CPU inference.

    Args:
        model_path: Filesystem path to the GGML model binary. Defaults to
            the original hard-coded location so existing callers are
            unaffected; pass a different path on other machines.

    Returns:
        A configured CTransformers LLM instance.
    """
    # NOTE(review): the default path points at one user's Downloads folder.
    # Consider sourcing it from an environment variable or config file.
    llm = CTransformers(
        model=model_path,
        model_type="llama",
        max_new_tokens=512,
        temperature=0.5,
    )
    return llm
40
+
41
def retrieval_QA_chain(llm, prompt, db):
    """
    Wire an LLM, a prompt template and a vector store into a
    'stuff'-type RetrievalQA chain.

    The retriever pulls the 2 nearest chunks per query, and the chain
    returns the source documents alongside the generated answer.
    """
    retriever = db.as_retriever(search_kwargs={'k': 2})
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt},
    )
51
+
52
+
53
def qa_bot():
    """
    Assemble the full retrieval-QA pipeline: CPU embeddings, the local
    FAISS index, the Llama-2 LLM and the custom prompt.

    Returns:
        A ready-to-query RetrievalQA chain.
    """
    # Model id casing fixed: the canonical Hub repo is
    # 'sentence-transformers/all-MiniLM-L6-v2'; the previous
    # 'all-miniLM-L6-V2' spelling may fail to resolve on the Hub.
    embeddings = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        model_kwargs={'device': 'cpu'},
    )
    # NOTE(review): newer langchain versions require
    # allow_dangerous_deserialization=True here (the index is a pickle);
    # only load indexes you built yourself.
    db = FAISS.load_local(DB_FAISS_PATH, embeddings)
    llm = load_llm()
    qa_prompt = set_custom_prompt()
    qa = retrieval_QA_chain(llm, qa_prompt, db)
    return qa
60
+
61
def final_result(query):
    """Build the QA pipeline and answer a single query synchronously.

    Args:
        query: The user's question as a string.

    Returns:
        The chain's response dict (answer under 'result', plus sources).
    """
    pipeline = qa_bot()
    return pipeline({'query': query})
65
+
66
@cl.on_chat_start
async def start():
    """Chainlit session hook: greet the user and stash a fresh QA chain."""
    # Show a holding message first — building the chain loads the model
    # and can take a while.
    greeting = cl.Message(content="Starting the bot...")
    await greeting.send()

    qa_chain = qa_bot()
    cl.user_session.set("chain", qa_chain)

    greeting.content = "Hi, Welcome to the Medical bot. What is your query?"
    await greeting.update()
75
+
76
@cl.on_message
async def main(message):
    """
    Chainlit message hook: run the user's message through the session's
    QA chain (streaming the final answer) and reply with answer + sources.
    """
    chain = cl.user_session.get("chain")
    # Stream tokens once the final answer begins; the prefix tokens mark
    # where the final answer starts in the chain's output.
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    cb.answer_reached = True
    # NOTE(review): in recent Chainlit versions `message` is a cl.Message
    # object, so the chain may need `message.content` here — confirm
    # against the installed Chainlit version.
    res = await chain.acall(message, callbacks=[cb])
    answer = res["result"]
    sources = res["source_documents"]
    if sources:
        # Plain string literal (the original used an f-string with no
        # placeholders); output is unchanged.
        answer += "\nSources:" + str(sources)
    else:
        answer += "\nNo sources found"

    await cl.Message(content=answer).send()