"""Gradio chat app that answers questions with GPT-2 over a local FAISS index."""

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from langchain.chains import RetrievalQA
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableSequence
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
import gradio as gr

DB_FAISS_PATH = "vectorstores/db_faiss"
# Embedding model used to query the index. Presumably the same model that built
# DB_FAISS_PATH -- keep the two in sync or retrieval silently degrades.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-miniLM-L6-V2"


class GPT2LLM:
    """Callable wrapper around a GPT-2 model/tokenizer pair.

    Calling an instance with a prompt string returns only the newly generated
    continuation; the prompt itself is not echoed back.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def __call__(self, prompt_text, max_length=512, *, max_new_tokens=128):
        """Generate a continuation of ``prompt_text``.

        Args:
            prompt_text: Prompt to continue.
            max_length: Cap on prompt + generated tokens; additionally clamped
                to ``model.config.n_positions`` when the config defines it.
            max_new_tokens: Upper bound on generated tokens. The prompt is
                left-truncated so that this many tokens still fit under the cap.

        Returns:
            The decoded generated text, stripped of surrounding whitespace.

        Raises:
            ValueError: If the cap leaves no room for one prompt token plus one
                generated token.
        """
        limit = min(max_length, getattr(self.model.config, "n_positions", max_length))
        if limit < 2:
            raise ValueError(
                "max_length must allow at least one prompt token and one new token"
            )
        new_tokens = max(1, min(max_new_tokens, limit - 1))
        prompt_budget = limit - new_tokens

        encoded = self.tokenizer(prompt_text, return_tensors="pt")
        # Left-truncate: the end of the prompt (where set_custom_prompt puts the
        # question and the "Helpful answer:" cue) is what the model must see, and
        # an over-long prompt would otherwise overrun the position table.
        input_ids = encoded["input_ids"][:, -prompt_budget:]
        attention_mask = encoded["attention_mask"][:, -prompt_budget:]

        outputs = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=new_tokens,
            # temperature only takes effect when sampling is enabled.
            do_sample=True,
            temperature=0.5,
            # GPT-2 ships without a pad token; reuse EOS so generate() has one.
            pad_token_id=self.tokenizer.eos_token_id,
        )
        # generate() returns prompt + continuation; decode only the continuation.
        continuation = outputs[0, input_ids.shape[-1]:]
        return self.tokenizer.decode(continuation, skip_special_tokens=True).strip()


def load_llm(model_name="gpt2"):
    """Load a GPT-2 checkpoint and wrap it in :class:`GPT2LLM`.

    Args:
        model_name: Hugging Face model id or local path (default ``"gpt2"``).

    Returns:
        A ``GPT2LLM`` instance, or ``None`` if loading failed (the error is printed).
    """
    try:
        print("Downloading or loading the GPT-2 model and tokenizer...")
        model = GPT2LMHeadModel.from_pretrained(model_name)
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    except Exception as e:  # App boundary: report and start without a model.
        print(f"An error occurred while loading the model: {e}")
        return None
    print("Model and tokenizer successfully loaded!")
    return GPT2LLM(model, tokenizer)


def set_custom_prompt():
    """Return the QA prompt template with ``context`` and ``question`` variables."""
    custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context: {context}
Question: {question}

only return the helpful answer below and nothing else.
Helpful answer:"""
    return PromptTemplate(
        template=custom_prompt_template,
        input_variables=["context", "question"],
    )


def retrieval_QA_chain(llm, prompt, db, k=2):
    """Build a retrieval-QA runnable from a text-generation callable, a prompt and a vector store.

    The prompt is piped into ``llm`` with a RunnableSequence; retrieval and
    "stuffing" of the documents are done explicitly so that the returned
    runnable keeps RetrievalQA's input/output keys.

    Args:
        llm: Callable taking a prompt string and returning generated text
            (e.g. ``GPT2LLM``).
        prompt: ``PromptTemplate`` with ``context`` and ``question`` variables.
        db: Vector store exposing ``as_retriever`` (e.g. FAISS).
        k: Number of documents to retrieve per question.

    Returns:
        A runnable; ``.invoke({"query": question})`` returns a dict with keys
        ``"query"``, ``"result"`` (answer text) and ``"source_documents"``.
    """
    retriever = db.as_retriever(search_kwargs={"k": k})

    def _generate(prompt_value):
        # The template renders to a PromptValue; the llm callable wants a str.
        return llm(prompt_value.to_string())

    llm_chain: RunnableSequence = prompt | RunnableLambda(_generate)

    def _answer(inputs):
        question = inputs["query"]
        docs = retriever.invoke(question)
        # "Stuff" strategy: concatenate every retrieved chunk into one context.
        context = "\n\n".join(doc.page_content for doc in docs)
        result = llm_chain.invoke({"context": context, "question": question})
        return {"query": question, "result": result, "source_documents": docs}

    return RunnableLambda(_answer)


def qa_bot():
    """Assemble the QA runnable: LLM, embeddings, FAISS index and prompt.

    Returns:
        The runnable from ``retrieval_QA_chain``, or ``None`` if the LLM or the
        vector store could not be loaded (the error is printed).
    """
    llm = load_llm()
    if llm is None:
        # Skip the embedding/index load entirely when there is no model to use.
        return None
    try:
        embeddings = HuggingFaceEmbeddings(
            model_name=EMBEDDING_MODEL_NAME,
            model_kwargs={"device": "cpu"},
        )
        # SECURITY: load_local unpickles stored index data. Only point
        # DB_FAISS_PATH at an index you built yourself.
        db = FAISS.load_local(
            DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True
        )
    except Exception as e:  # App boundary: fall back to "not loaded" in the UI.
        print(f"An error occurred while loading the vector store: {e}")
        return None
    return retrieval_QA_chain(llm, set_custom_prompt(), db)


# Built once at import time; None when loading failed.
bot = qa_bot()


def _format_sources(docs):
    """Return a comma-separated, de-duplicated list of source labels for ``docs``."""
    labels = []
    for doc in docs:
        # Assumes loader-style metadata with optional "source"/"page" keys;
        # falls back gracefully when they are absent.
        metadata = getattr(doc, "metadata", None) or {}
        label = str(metadata.get("source", "unknown source"))
        if "page" in metadata:
            label += f" (page {metadata['page']})"
        labels.append(label)
    # dict.fromkeys de-duplicates while keeping retrieval order.
    return ", ".join(dict.fromkeys(labels))


def chatbot_response(message, history):
    """Answer ``message`` and append the exchange to the chat history.

    Args:
        message: The user's question from the textbox.
        history: List of ``(user, bot)`` tuples held in ``gr.State``;
            ``None`` is treated as empty.

    Returns:
        The updated history twice: once for the Chatbot display and once for
        the State output.
    """
    history = list(history or [])
    if not message or not message.strip():
        return history, history
    if bot is None:
        history.append((message, "Model is not loaded properly."))
        return history, history
    try:
        response = bot.invoke({"query": message})
        answer = response["result"]
        sources = response.get("source_documents", [])
        if sources:
            answer += f"\nSources: {_format_sources(sources)}"
        else:
            answer += "\nNo sources found"
    except Exception as e:  # UI boundary: show the error in the chat.
        answer = f"An error occurred: {e}"
    history.append((message, answer))
    return history, history


# Set up the Gradio interface
demo = gr.Interface(
    fn=chatbot_response,
    inputs=[
        gr.Textbox(label="User Input"),
        # gr.State is invisible; its constructor takes no `label` in Gradio 4+.
        gr.State(value=[]),
    ],
    outputs=[
        gr.Chatbot(label="Chatbot Response"),
        gr.State(),
    ],
    title="AdvocateAI",
    description="Ask questions about AI rights and get informed, passionate answers.",
)

if __name__ == "__main__":
    demo.launch()