from typing import Any, List, Optional

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from langchain_core.prompts import PromptTemplate
from langchain_core.language_models.llms import LLM
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
import gradio as gr

DB_FAISS_PATH = "vectorstores/db_faiss"
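
# The FAISS index at DB_FAISS_PATH is assumed to already exist on disk. A minimal
# ingestion sketch (hypothetical paths and loaders; adjust to the actual corpus)
# would look roughly like this:
#
#     from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
#     from langchain_text_splitters import RecursiveCharacterTextSplitter
#
#     docs = DirectoryLoader("data/", glob="*.pdf", loader_cls=PyPDFLoader).load()
#     chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
#     embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
#     FAISS.from_documents(chunks, embeddings).save_local(DB_FAISS_PATH)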

class GPT2LLM(LLM):
    """
    A thin LangChain LLM wrapper around the GPT-2 model and tokenizer,
    so the model can be plugged into RetrievalQA.
    """
    model: Any
    tokenizer: Any

    @property
    def _llm_type(self) -> str:
        return "gpt2"

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        # Encode the prompt, sample a completion (do_sample is required for the
        # temperature setting to take effect), and decode it back to text.
        inputs = self.tokenizer.encode(prompt, return_tensors='pt')
        outputs = self.model.generate(
            inputs, max_new_tokens=256, do_sample=True,
            temperature=0.5, pad_token_id=self.tokenizer.eos_token_id
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
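
# Hypothetical usage note: once instantiated, GPT2LLM behaves like any other
# LangChain LLM, e.g. GPT2LLM(model=model, tokenizer=tokenizer).invoke("Hello")
# returns a plain string completion.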

def load_llm():
    """
    Load the GPT-2 model for the language model.
    """
    try:
        print("Downloading or loading the GPT-2 model and tokenizer...")
        model_name = 'gpt2'
        model = GPT2LMHeadModel.from_pretrained(model_name)
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        print("Model and tokenizer successfully loaded!")
        return GPT2LLM(model=model, tokenizer=tokenizer)
    except Exception as e:
        print(f"An error occurred while loading the model: {e}")
        return None

def set_custom_prompt():
    """
    Define a custom prompt template for the QA model.
    """
    custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context: {context}
Question: {question}

Only return the helpful answer below and nothing else.
Helpful answer:
"""
    prompt = PromptTemplate(template=custom_prompt_template, input_variables=['context', 'question'])
    return prompt

def retrieval_QA_chain(llm, prompt, db):
    """
    Create a RetrievalQA chain that stuffs the retrieved documents into the
    custom prompt and answers with the supplied LLM.
    """
    qachain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={'k': 2}),
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt}
    )
    return qachain

def qa_bot():
    """
    Initialize the QA bot with embeddings, vector store, LLM, and prompt.
    """
    embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2', model_kwargs={'device': 'cpu'})
    db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
    llm = load_llm()
    qa_prompt = set_custom_prompt()
    if llm:
        qa = retrieval_QA_chain(llm, qa_prompt, db)
    else:
        qa = None
    return qa

bot = qa_bot()

def chatbot_response(message, history):
    """
    Generate a response from the chatbot based on the user input and conversation history.
    """
    try:
        if bot:
            response = bot.invoke({'query': message})
            answer = response["result"]
            sources = response.get("source_documents", [])
            if sources:
                answer += f"\nSources: {sources}"
            else:
                answer += "\nNo sources found"
            history.append((message, answer))
        else:
            history.append((message, "Model is not loaded properly."))
    except Exception as e:
        history.append((message, f"An error occurred: {str(e)}"))
    return history, history

# Set up the Gradio interface
demo = gr.Interface(
    fn=chatbot_response,
    inputs=[
        gr.Textbox(label="User Input"),
        gr.State(value=[])  # conversation history carried between turns
    ],
    outputs=[
        gr.Chatbot(label="Chatbot Response"),
        gr.State()
    ],
    title="AdvocateAI",
    description="Ask questions about AI rights and get informed, passionate answers."
)

if __name__ == "__main__":
    demo.launch()