from typing import Any, List, Optional

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from langchain_core.language_models.llms import LLM
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
import gradio as gr
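# Path to the FAISS index on disk. The index is assumed to have been built
# beforehand (e.g. by a separate ingestion script) using the same embedding
# model that qa_bot() loads below.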
DB_FAISS_PATH = "vectorstores/db_faiss"
class GPT2LLM(LLM):
    """
    A custom LangChain LLM wrapper around a local GPT-2 model and tokenizer,
    so it can be passed directly to chains such as RetrievalQA.
    """
    model: Any
    tokenizer: Any

    @property
    def _llm_type(self) -> str:
        return "gpt2"

    def _call(self, prompt: str, stop: Optional[List[str]] = None, run_manager=None, **kwargs) -> str:
        inputs = self.tokenizer.encode(prompt, return_tensors='pt')
        # Sampling must be enabled for temperature to take effect; cap the number of
        # new tokens so long retrieved context does not exhaust GPT-2's context window.
        outputs = self.model.generate(
            inputs, max_new_tokens=256, do_sample=True, temperature=0.5,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        # Return only the newly generated tokens, not the echoed prompt.
        return self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
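# Minimal usage sketch (assumes the GPT-2 weights are available locally or can be
# downloaded): GPT2LLM(model=model, tokenizer=tokenizer).invoke("What are AI rights?")
# returns the generated continuation as a plain string.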
def load_llm():
    """
    Load the GPT-2 model and tokenizer and wrap them in GPT2LLM.
    """
    try:
        print("Downloading or loading the GPT-2 model and tokenizer...")
        model_name = 'gpt2'
        model = GPT2LMHeadModel.from_pretrained(model_name)
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        print("Model and tokenizer successfully loaded!")
        return GPT2LLM(model=model, tokenizer=tokenizer)
    except Exception as e:
        print(f"An error occurred while loading the model: {e}")
        return None
def set_custom_prompt():
    """
    Define a custom prompt template for the QA chain.
    """
    custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {context}
Question: {question}
Only return the helpful answer below and nothing else.
Helpful answer:
"""
    prompt = PromptTemplate(template=custom_prompt_template, input_variables=['context', 'question'])
    return prompt
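# When the "stuff" chain runs, the retrieved document chunks are concatenated into
# {context} and the user's query fills {question} before the prompt is sent to the LLM.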
def retrieval_QA_chain(llm, prompt, db):
    """
    Create a RetrievalQA chain with the specified LLM, custom prompt, and vector store.
    """
    # The custom prompt is passed via chain_type_kwargs; from_chain_type builds the
    # underlying LLM chain itself, so no pre-built sequence is needed here.
    qachain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={'k': 2}),
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt},
    )
    return qachain
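# Note: k=2 keeps the stuffed context small; GPT-2 only has a 1024-token context
# window, so larger k values can overflow the prompt.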
def qa_bot():
    """
    Initialize the QA bot with embeddings, vector store, LLM, and prompt.
    """
    embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2',
                                       model_kwargs={'device': 'cpu'})
    db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
    llm = load_llm()
    qa_prompt = set_custom_prompt()
    if llm:
        qa = retrieval_QA_chain(llm, qa_prompt, db)
    else:
        qa = None
    return qa
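# Build the bot once at import time so the model and index are loaded only once
# per process rather than on every request.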
bot = qa_bot()
def chatbot_response(message, history):
    """
    Generate a response from the chatbot based on the user input and conversation history.
    """
    try:
        if bot:
            response = bot.invoke({'query': message})
            answer = response["result"]
            sources = response.get("source_documents", [])
            if sources:
                answer += f"\nSources: {sources}"
            else:
                answer += "\nNo sources found"
            history.append((message, answer))
        else:
            history.append((message, "Model is not loaded properly."))
    except Exception as e:
        history.append((message, f"An error occurred: {str(e)}"))
    return history, history
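# The history is returned twice: the first value refreshes the Chatbot display and
# the second updates the gr.State that is fed back in on the next call.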
# Set up the Gradio interface
demo = gr.Interface(
    fn=chatbot_response,
    inputs=[
        gr.Textbox(label="User Input"),
        gr.State(value=[])  # holds the conversation history between calls
    ],
    outputs=[
        gr.Chatbot(label="Chatbot Response"),
        gr.State()
    ],
    title="AdvocateAI",
    description="Ask questions about AI rights and get informed, passionate answers."
)
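# Running the script directly starts a local Gradio server and prints its URL;
# on a Hugging Face Space, the same launch() call serves the app.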
if __name__ == "__main__":
    demo.launch()