from langchain.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms import CTransformers
from langchain.chains import RetrievalQA
import gradio as gr
from huggingface_hub import hf_hub_download
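
# Path to the FAISS index. This is assumed to have been built beforehand by a
# separate ingestion step that embeds the source documents; the app only loads it.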
DB_FAISS_PATH = "vectorstores/db_faiss"
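
# NOTE: hf_hub_download fetches the quantized GGML weights (roughly 7 GB for
# the q8_0 variant) on first run and caches them under ./models.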

def load_llm():
    """
    Load the quantized LLaMA 2 chat model (GGML) with ctransformers.
    """
    model_name = 'TheBloke/Llama-2-7B-Chat-GGML'
    model_path = hf_hub_download(
        repo_id=model_name,
        filename='llama-2-7b-chat.ggmlv3.q8_0.bin',
        cache_dir='./models'
    )
    llm = CTransformers(
        model=model_path,
        model_type="llama",
        max_new_tokens=512,
        temperature=0.5
    )
    return llm
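
# The prompt template below exposes two input variables, {context} and
# {question}, which the retrieval chain fills in at query time.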

def set_custom_prompt():
    """
    Define a custom prompt template for the QA chain.
    """
    custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know; don't try to make up an answer.

Context: {context}
Question: {question}

Only return the helpful answer below and nothing else.
Helpful answer:
"""
    prompt = PromptTemplate(template=custom_prompt_template, input_variables=['context', 'question'])
    return prompt
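
# chain_type="stuff" concatenates the retrieved documents directly into the
# prompt, and k=2 limits retrieval to the two most similar chunks.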

def retrieval_QA_chain(llm, prompt, db):
    """
    Create a RetrievalQA chain from the given LLM, prompt, and vector store.
    """
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={'k': 2}),
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt}
    )
    return qa_chain
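
# NOTE: the embedding model used here must match the one used when the FAISS
# index was built, otherwise similarity search will return poor matches.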

def qa_bot():
    """
    Initialize the QA bot: embeddings, vector store, LLM, and prompt.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        model_kwargs={'device': 'cpu'}
    )
    db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
    llm = load_llm()
    qa_prompt = set_custom_prompt()
    qa = retrieval_QA_chain(llm, qa_prompt, db)
    return qa


bot = qa_bot()
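
# The conversation history is only used to render the chat transcript in the
# UI; each query is sent to the chain on its own, without prior turns.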

def chatbot_response(message, history):
    """
    Generate a response for the user's message and append it to the
    conversation history.
    """
    try:
        response = bot({'query': message})
        answer = response["result"]
        sources = response["source_documents"]
        if sources:
            answer += f"\nSources: {sources}"
        else:
            answer += "\nNo sources found"
        history.append((message, answer))
    except Exception as e:
        history.append((message, f"An error occurred: {str(e)}"))
    return history, history
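
# The Chatbot component renders the history as (user message, bot reply)
# pairs, which is exactly what chatbot_response returns.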

# Set up the Gradio interface
demo = gr.Interface(
    fn=chatbot_response,
    inputs=[
        gr.Textbox(label="User Input"),
        gr.State(value=[])  # holds the conversation history
    ],
    outputs=[
        gr.Chatbot(label="Chatbot Response"),
        gr.State()
    ],
    title="AdvocateAI",
    description="Ask questions about AI rights and get informed, passionate answers."
)

if __name__ == "__main__":
    demo.launch()
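
# demo.launch() starts a local Gradio server; when this file is run as a
# Hugging Face Space, the same entry point serves the hosted app.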