# AdvocateAI / app.py
from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline
from langchain.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFacePipeline  # wraps the local GPT-2 pipeline so LangChain can use it
from langchain.chains import RetrievalQA
import gradio as gr

DB_FAISS_PATH = "vectorstores/db_faiss"
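# NOTE: the FAISS index is only loaded here; it must already exist at DB_FAISS_PATH.
# A minimal sketch of how such an index could be built offline (the data path and
# loader below are illustrative assumptions, not part of this app):
#
#     from langchain_community.document_loaders import DirectoryLoader
#     from langchain.text_splitter import RecursiveCharacterTextSplitter
#     docs = DirectoryLoader("data/").load()
#     chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
#     embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
#     FAISS.from_documents(chunks, embeddings).save_local(DB_FAISS_PATH)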
def load_llm():
"""
Load the GPT-2 model for the language model.
"""
try:
print("Downloading or loading the GPT-2 model and tokenizer...")
model_name = 'gpt2'
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
print("Model and tokenizer successfully loaded!")
return model, tokenizer
except Exception as e:
print(f"An error occurred while loading the model: {e}")
return None, None
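# Quick manual sanity check for load_llm (illustrative, not called by the app):
#     model, tokenizer = load_llm()
#     ids = tokenizer.encode("AI rights are", return_tensors="pt")
#     print(tokenizer.decode(model.generate(ids, max_new_tokens=20, pad_token_id=tokenizer.eos_token_id)[0]))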
def set_custom_prompt():
"""
Define a custom prompt template for the QA model.
"""
custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {context}
Question: {question}
Only return the helpful answer below and nothing else.
Helpful answer:
"""
prompt = PromptTemplate(template=custom_prompt_template, input_variables=['context', 'question'])
return prompt
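# Example of the filled-in prompt (illustrative values):
#     set_custom_prompt().format(context="<retrieved passages>", question="What rights should AI have?")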
def retrieval_QA_chain(llm, tokenizer, prompt, db):
    """
    Create a RetrievalQA chain with the specified LLM, prompt, and vector store.
    """
    # RetrievalQA expects a LangChain LLM object, not a bare Python function,
    # so wrap the local GPT-2 model in a transformers text-generation pipeline.
    # Note: GPT-2's context window is 1024 tokens, so the retrieved chunks plus
    # the prompt must stay short.
    text_generation = pipeline(
        "text-generation",
        model=llm,
        tokenizer=tokenizer,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.5,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False,
    )
    hf_llm = HuggingFacePipeline(pipeline=text_generation)
    qachain = RetrievalQA.from_chain_type(
        llm=hf_llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={'k': 2}),
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt}
    )
    return qachain
def qa_bot():
"""
Initialize the QA bot with embeddings, vector store, LLM, and prompt.
"""
    embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2', model_kwargs={'device': 'cpu'})
db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
model, tokenizer = load_llm()
qa_prompt = set_custom_prompt()
if model and tokenizer:
qa = retrieval_QA_chain(model, tokenizer, qa_prompt, db)
else:
qa = None
return qa
bot = qa_bot()
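# The chain is built once at startup; if the GPT-2 model fails to load,
# `bot` stays None and chatbot_response reports that to the user.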
def chatbot_response(message, history):
    """
    Generate a response from the chatbot based on the user input and conversation history.
    """
    try:
        if bot is not None:
            response = bot.invoke({'query': message})
            answer = response["result"]
            sources = response.get("source_documents", [])
            if sources:
                # List only the source references instead of dumping full Document objects.
                source_names = ", ".join(str(doc.metadata.get("source", "unknown")) for doc in sources)
                answer += f"\nSources: {source_names}"
            else:
                answer += "\nNo sources found"
            history.append((message, answer))
        else:
            history.append((message, "Model is not loaded properly."))
    except Exception as e:
        history.append((message, f"An error occurred: {str(e)}"))
    return history, history
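# With return_source_documents=True, the chain's output is a dict of the form
# {"result": "<answer text>", "source_documents": [Document, ...]}.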
# Set up the Gradio interface
demo = gr.Interface(
fn=chatbot_response,
inputs=[
gr.Textbox(label="User Input"),
        gr.State(value=[])  # conversation history
],
outputs=[
gr.Chatbot(label="Chatbot Response"),
gr.State()
],
title="AdvocateAI",
description="Ask questions about AI rights and get informed, passionate answers."
)
if __name__ == "__main__":
demo.launch()