Manasa1 commited on
Commit
9ecc32c
·
verified ·
1 Parent(s): d8238d7

Create app

Browse files
Files changed (1) hide show
  1. app +100 -0
app ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain import PromptTemplate
2
+ from langchain_huggingface import HuggingFaceEmbeddings
3
+ from langchain_community.vectorstores import FAISS
4
+ from langchain_community.llms import CTransformers
5
+ from langchain.chains import RetrievalQA
6
+ import gradio as gr
7
+ from huggingface_hub import hf_hub_download
8

# Path to the prebuilt FAISS vector index (generated offline by a separate
# ingestion step); loaded read-only in qa_bot().
DB_FAISS_PATH = "vectorstores/db_faiss"

11
def load_llm():
    """
    Download the quantized LLaMA-2 chat weights and wrap them in a
    CTransformers LLM.

    Returns:
        CTransformers: CPU-friendly GGML llama model, capped at 512 new
        tokens per generation with moderate (0.5) sampling temperature.
    """
    repo = 'TheBloke/Llama-2-7B-Chat-GGML'
    # hf_hub_download caches under ./models, so repeat launches skip the fetch.
    weights_path = hf_hub_download(
        repo_id=repo,
        filename='llama-2-7b-chat.ggmlv3.q8_0.bin',
        cache_dir='./models',
    )
    return CTransformers(
        model=weights_path,
        model_type="llama",
        max_new_tokens=512,
        temperature=0.5,
    )
24
+
25
def set_custom_prompt():
    """
    Build the prompt template used by the RetrievalQA chain.

    Returns:
        PromptTemplate: template expecting 'context' and 'question'
        variables, instructing the model to answer only from the retrieved
        context and to admit ignorance rather than hallucinate.
    """
    template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context: {context}
Question: {question}

only return the helpful answer below and nothing else.
Helpful answer:
"""
    return PromptTemplate(
        template=template,
        input_variables=['context', 'question'],
    )
40
+
41
def retrieval_QA_chain(llm, prompt, db):
    """
    Assemble a "stuff"-type RetrievalQA chain over the given vector store.

    Args:
        llm: language model used to generate answers.
        prompt: PromptTemplate injected into the chain.
        db: vector store exposing ``as_retriever``.

    Returns:
        RetrievalQA: chain that retrieves the top-2 documents per query and
        returns the source documents alongside the answer.
    """
    # k=2 keeps the stuffed context small enough for the 512-token budget.
    retriever = db.as_retriever(search_kwargs={'k': 2})
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt},
    )
    return chain
53
+
54
def qa_bot():
    """
    Initialize the QA bot: embeddings, FAISS vector store, LLM, and prompt.

    Returns:
        RetrievalQA: ready-to-query chain wired to the local FAISS index.
    """
    # FIX: the canonical Hub repo is 'sentence-transformers/all-MiniLM-L6-v2';
    # the original casing ('all-miniLM-L6-V2') does not match the published
    # model id and can fail to resolve on the Hugging Face Hub.
    embeddings = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        model_kwargs={'device': 'cpu'},
    )
    # allow_dangerous_deserialization: FAISS.load_local unpickles the index
    # file. Acceptable only because the index at DB_FAISS_PATH is produced
    # locally by our own ingestion step, never from untrusted input.
    db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
    llm = load_llm()
    qa_prompt = set_custom_prompt()
    return retrieval_QA_chain(llm, qa_prompt, db)
64
+
65
# Build the chain once at import time so every Gradio request reuses it
# (the model download and FAISS load are expensive).
bot = qa_bot()
66
+
67
def chatbot_response(message, history):
    """
    Answer *message* with the module-level QA chain and record the turn.

    Args:
        message: the user's question.
        history: list of (message, answer) pairs; mutated in place.

    Returns:
        tuple: the updated history twice — once for the Chatbot display,
        once for the State component.
    """
    try:
        result = bot({'query': message})
        reply = result["result"]
        docs = result["source_documents"]
        # Append the raw source documents so users can see provenance.
        reply += f"\nSources: {docs}" if docs else "\nNo sources found"
        history.append((message, reply))
    except Exception as exc:
        # Surface failures in the chat itself rather than crashing the UI.
        history.append((message, f"An error occurred: {str(exc)}"))
    return history, history
83
+
84
# Set up the Gradio interface: a textbox for the question plus a State that
# carries the conversation history between calls; chatbot_response returns
# the history twice, feeding both the Chatbot display and the State.
demo = gr.Interface(
    fn=chatbot_response,
    inputs=[
        gr.Textbox(label="User Input"),
        # NOTE(review): newer Gradio releases may reject a `label` kwarg on
        # gr.State (it is a non-rendered component) — confirm against the
        # installed Gradio version.
        gr.State(value=[], label="Conversation History")
    ],
    outputs=[
        gr.Chatbot(label="Chatbot Response"),
        # Receives the second return value so history persists across turns.
        gr.State()
    ],
    title="AdvocateAI",
    description="Ask questions about AI rights and get informed, passionate answers."
)

if __name__ == "__main__":
    demo.launch()