from typing import Any

import gradio as gr
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from langchain_core.prompts import PromptTemplate
from langchain_core.language_models.llms import LLM
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA

DB_FAISS_PATH = "vectorstores/db_faiss"
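# NOTE: the FAISS index at DB_FAISS_PATH is assumed to have been built beforehand
# with the same embedding model used in qa_bot(). A minimal sketch (illustrative;
# `docs` is a list of LangChain Documents you have already loaded and split):
#   db = FAISS.from_documents(docs, embeddings)
#   db.save_local(DB_FAISS_PATH)
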
class GPT2LLM(LLM):
    """A custom LangChain LLM wrapper around a local GPT-2 model and tokenizer."""

    model: Any
    tokenizer: Any

    @property
    def _llm_type(self) -> str:
        return "gpt2"

    def _call(self, prompt: str, stop=None, run_manager=None, **kwargs) -> str:
        inputs = self.tokenizer.encode(prompt, return_tensors="pt")
        outputs = self.model.generate(
            inputs,
            max_new_tokens=256,       # cap the generated continuation, not the whole sequence
            do_sample=True,           # sampling must be enabled for temperature to take effect
            temperature=0.5,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

def load_llm():
    """
    Load the GPT-2 model and tokenizer and wrap them for LangChain.
    """
    try:
        print("Downloading or loading the GPT-2 model and tokenizer...")
        model_name = "gpt2"
        model = GPT2LMHeadModel.from_pretrained(model_name)
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        print("Model and tokenizer successfully loaded!")
        # GPT2LLM is a pydantic model, so its fields are passed by keyword.
        return GPT2LLM(model=model, tokenizer=tokenizer)
    except Exception as e:
        print(f"An error occurred while loading the model: {e}")
        return None

def set_custom_prompt():
    """
    Define a custom prompt template for the QA chain.
    """
    custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context: {context}
Question: {question}

Only return the helpful answer below and nothing else.
Helpful answer:
"""
    prompt = PromptTemplate(template=custom_prompt_template, input_variables=['context', 'question'])
    return prompt

def retrieval_QA_chain(llm, prompt, db):
    """
    Create a RetrievalQA chain with the specified LLM, prompt, and vector store.
    """
    qachain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={'k': 2}),
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt},
    )
    return qachain
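# With return_source_documents=True, the chain returns a dict containing a "result"
# string and a "source_documents" list, e.g. qa.invoke({"query": "..."}).
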
def qa_bot():
    """
    Initialize the QA bot with embeddings, vector store, LLM, and prompt.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        model_kwargs={'device': 'cpu'},
    )
    db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
    llm = load_llm()
    qa_prompt = set_custom_prompt()
    if llm:
        qa = retrieval_QA_chain(llm, qa_prompt, db)
    else:
        qa = None
    return qa

bot = qa_bot()

def chatbot_response(message, history):
    """
    Generate a response from the chatbot based on the user input and conversation history.
    """
    try:
        if bot:
            response = bot.invoke({'query': message})
            answer = response["result"]
            sources = response.get("source_documents", [])
            if sources:
                answer += f"\nSources: {sources}"
            else:
                answer += "\nNo sources found"
            history.append((message, answer))
        else:
            history.append((message, "Model is not loaded properly."))
    except Exception as e:
        history.append((message, f"An error occurred: {str(e)}"))
    return history, history

# Set up the Gradio interface
demo = gr.Interface(
    fn=chatbot_response,
    inputs=[
        gr.Textbox(label="User Input"),
        gr.State(value=[]),  # conversation history (gr.State does not take a label)
    ],
    outputs=[
        gr.Chatbot(label="Chatbot Response"),
        gr.State(),
    ],
    title="AdvocateAI",
    description="Ask questions about AI rights and get informed, passionate answers.",
)

if __name__ == "__main__":
    demo.launch()
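# Running `python app.py` serves the Gradio app locally (http://127.0.0.1:7860 by
# default); on Hugging Face Spaces the app is launched automatically.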