Manasa1 committed on
Commit
0120d98
·
verified ·
1 Parent(s): faa1e33

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -59
app.py CHANGED
@@ -3,7 +3,7 @@ from langchain.embeddings import HuggingFaceEmbeddings
3
  from langchain.vectorstores import FAISS
4
  from langchain.llms import CTransformers
5
  from langchain.chains import RetrievalQA
6
- import chainlit as cl
7
 
8
  DB_FAISS_PATH = "vectorstores/db_faiss"
9
 
@@ -18,74 +18,58 @@ Helpful answer:
18
  """
19
 
20
  def set_custom_prompt():
21
- """
22
- Prompt template for QA retrieval for each vectorstore
23
- """
24
- prompt = PromptTemplate(template=custom_prompt_template,
25
- input_variables=['context', 'question'])
26
  return prompt
27
 
28
-
29
  def load_llm():
30
- """
31
- Load the language model
32
- """
33
  llm = CTransformers(
34
  model="C:/Users/sanath/Downloads/llama-2-7b-chat.ggmlv3.q8_0.bin",
35
- model_type = "llama",
36
- max_new_tokens = 512,
37
- temperature = 0.5
38
  )
39
  return llm
40
 
41
- def retrieval_QA_chain(llm,prompt,db):
42
- qachain = RetrievalQA.from_chain_type(
43
- llm=llm,
44
- chain_type="stuff",
45
- retriever=db.as_retriever(search_kwargs={'k':2}),
46
- return_source_documents=True,
47
- chain_type_kwargs={'prompt':prompt}
48
-
49
- )
50
- return qachain
51
-
52
 
53
  def qa_bot():
54
- embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-miniLM-L6-V2', model_kwargs={'device': 'cpu'})
55
- db = FAISS.load_local(DB_FAISS_PATH,embeddings)
56
- llm = load_llm()
57
- qa_prompt = set_custom_prompt()
58
- qa = retrieval_QA_chain(llm,qa_prompt,db)
59
- return qa
60
-
61
- def final_result(query):
62
- qa_result = qa_bot()
63
- response = qa_result({'query':query})
64
- return response
65
 
66
- @cl.on_chat_start
67
- async def start():
68
- chain = qa_bot()
69
- msg = cl.Message(content = "Starting the bot...")
70
- await msg.send()
71
- msg.content = "Hi, Welcome to the Medical bot. What is your query?"
72
- await msg.update()
 
 
 
 
 
 
73
 
74
- cl.user_session.set("chain",chain)
 
 
 
 
 
 
 
75
 
76
- @cl.on_message
77
- async def main(message):
78
- chain = cl.user_session.get("chain")
79
- cb = cl.AsyncLangchainCallbackHandler(
80
- stream_final_answer = True, answer_prefix_tokens = ["FINAL","ANSWER"]
81
- )
82
- cb.answer_reached = True
83
- res = await chain.acall(message, callbacks=[cb])
84
- answer = res["result"]
85
- sources = res["source_documents"]
86
- if sources:
87
- answer += f"\nSources:" + str(sources)
88
- else:
89
- answer += "\nNo sources found"
90
-
91
- await cl.Message(content=answer).send()
 
3
  from langchain.vectorstores import FAISS
4
  from langchain.llms import CTransformers
5
  from langchain.chains import RetrievalQA
6
+ import gradio as gr
7
 
8
  DB_FAISS_PATH = "vectorstores/db_faiss"
9
 
 
18
  """
19
 
20
def set_custom_prompt():
    """Build the prompt used for every retrieval-QA call.

    Renders the module-level ``custom_prompt_template`` into a
    PromptTemplate expecting the 'context' and 'question' variables.
    """
    variables = ['context', 'question']
    return PromptTemplate(
        template=custom_prompt_template,
        input_variables=variables,
    )
23
 
 
24
def load_llm(model_path=None):
    """Load the local Llama 2 chat model via CTransformers.

    Parameters
    ----------
    model_path : str | None
        Filesystem path to the GGML model file. Defaults to the
        ``LLM_MODEL_PATH`` environment variable, falling back to the
        original hard-coded location for backward compatibility.

    Returns
    -------
    A configured CTransformers LLM instance.
    """
    import os  # local import keeps this block self-contained

    if model_path is None:
        # The original hard-coded Windows path only exists on the
        # author's machine and breaks on any deployment (e.g. a
        # Hugging Face Space); let operators override it via env var.
        model_path = os.environ.get(
            "LLM_MODEL_PATH",
            "C:/Users/sanath/Downloads/llama-2-7b-chat.ggmlv3.q8_0.bin",
        )
    llm = CTransformers(
        model=model_path,
        model_type="llama",
        max_new_tokens=512,
        temperature=0.5,
    )
    return llm
32
 
33
def retrieval_QA_chain(llm, prompt, db):
    """Wire an LLM, a prompt, and a FAISS store into a RetrievalQA chain.

    Uses the 'stuff' chain type, retrieves the top-2 matching documents,
    and returns the source documents alongside each answer.
    """
    retriever = db.as_retriever(search_kwargs={'k': 2})
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt},
    )
 
 
42
 
43
def qa_bot():
    """Assemble the full QA pipeline: embeddings, FAISS index, LLM, prompt.

    Returns
    -------
    A ready-to-call RetrievalQA chain.
    """
    # Fixed casing: the original id 'sentence-transformers/all-miniLM-L6-V2'
    # does not match the published model 'all-MiniLM-L6-v2' (Hub repo names
    # are case-sensitive).
    embeddings = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        model_kwargs={'device': 'cpu'},
    )
    # allow_dangerous_deserialization is required by newer LangChain to
    # unpickle a locally-built index; acceptable here only because the
    # index at DB_FAISS_PATH was built by this project itself.
    db = FAISS.load_local(DB_FAISS_PATH, embeddings,
                          allow_dangerous_deserialization=True)
    llm = load_llm()
    qa_prompt = set_custom_prompt()
    return retrieval_QA_chain(llm, qa_prompt, db)
 
 
 
 
 
50
 
51
def chatbot_response(query):
    """Answer a user query through the QA chain and append source info.

    The chain is built once on first use and cached on the function
    object: the original rebuilt the embeddings, the FAISS index, and
    the LLM on every single query, which is prohibitively slow.

    Parameters
    ----------
    query : str
        The user's question from the Gradio textbox.

    Returns
    -------
    str
        The answer with a sources section, or an error message.
    """
    try:
        chain = getattr(chatbot_response, "_chain", None)
        if chain is None:
            chain = qa_bot()
            chatbot_response._chain = chain
        response = chain({'query': query})
        answer = response["result"]
        sources = response["source_documents"]
        if sources:
            # Dropped the stray f-prefix: the string had no placeholders.
            answer += "\nSources:" + str(sources)
        else:
            answer += "\nNo sources found"
        return answer
    except Exception as e:
        # UI boundary: surface the error to the user instead of crashing
        # the Gradio worker.
        return f"An error occurred: {str(e)}"
64
 
65
# Gradio UI: a single textbox in, a plain-text answer out.
iface = gr.Interface(
    fn=chatbot_response,
    inputs=gr.Textbox(lines=2, placeholder="Enter your question..."),
    outputs="text",
    title="Medical Chatbot",
    description="Ask a medical question and get answers based on the provided context.",
)

# Guard the launch so importing this module (e.g. from tests or other
# tooling) does not start a web server as a side effect.
if __name__ == "__main__":
    iface.launch()