from transformers import pipeline
from langchain.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFacePipeline  # wraps a transformers pipeline as a LangChain LLM
from langchain.chains import RetrievalQA
import gradio as gr

DB_FAISS_PATH = "vectorstores/db_faiss"
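
# The FAISS index must already exist at DB_FAISS_PATH, built with the same
# embedding model used in qa_bot() below. A minimal ingestion sketch
# (assumed; the loader and "data/" path are placeholders, not part of this project):
#
#   from langchain_community.document_loaders import PyPDFDirectoryLoader
#   from langchain.text_splitter import RecursiveCharacterTextSplitter
#   docs = PyPDFDirectoryLoader("data/").load()
#   chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
#   FAISS.from_documents(chunks, embeddings).save_local(DB_FAISS_PATH)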

def load_llm():
    """
    Load GPT-2 wrapped as a LangChain-compatible LLM.
    """
    try:
        print("Downloading or loading the GPT-2 model and tokenizer...")
        # A raw (model, tokenizer) pair cannot be passed to RetrievalQA;
        # wrapping a text-generation pipeline in HuggingFacePipeline yields
        # an object LangChain can call directly.
        generator = pipeline(
            'text-generation',
            model='gpt2',
            max_new_tokens=256,
            do_sample=True,          # required for temperature to take effect
            temperature=0.5,
            return_full_text=False,  # return only the completion, not the echoed prompt
        )
        print("Model and tokenizer successfully loaded!")
        return HuggingFacePipeline(pipeline=generator)
    except Exception as e:
        print(f"An error occurred while loading the model: {e}")
        return None

def set_custom_prompt():
    """
    Define a custom prompt template for the QA model.
    """
    custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context: {context}
Question: {question}

Only return the helpful answer below and nothing else.
Helpful answer:
"""
    prompt = PromptTemplate(template=custom_prompt_template, input_variables=['context', 'question'])
    return prompt
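
# Example rendering (illustrative values): prompt.format(context="...",
# question="...") yields the final string sent to the LLM, with both
# variables substituted.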

def retrieval_QA_chain(llm, prompt, db):
    """
    Create a RetrievalQA chain with the specified LLM, prompt, and vector store.
    """
    # "stuff" concatenates the top-k retrieved chunks into the {context}
    # slot of the prompt before making a single LLM call.
    qachain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={'k': 2}),
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt}
    )
    return qachain
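
# With return_source_documents=True the chain's output is a dict of the form
# {'result': <answer text>, 'source_documents': [Document, ...]}, which
# chatbot_response below unpacks.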

def qa_bot():
    """
    Initialize the QA bot with embeddings, vector store, LLM, and prompt.
    """
    embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2',
                                       model_kwargs={'device': 'cpu'})
    db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
    llm = load_llm()
    qa_prompt = set_custom_prompt()
    if llm:
        qa = retrieval_QA_chain(llm, qa_prompt, db)
    else:
        qa = None
    return qa

bot = qa_bot()
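
# Quick sanity check (hypothetical query; RetrievalQA chains take a dict with
# a "query" key):
#   if bot:
#       print(bot.invoke({'query': 'What rights should an AI have?'})['result'])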

def chatbot_response(message, history):
    """
    Generate a response from the chatbot based on the user input and conversation history.
    """
    try:
        if bot:
            response = bot.invoke({'query': message})
            answer = response["result"]
            sources = response.get("source_documents", [])
            if sources:
                # List where each chunk came from instead of dumping raw Document reprs.
                source_names = sorted({doc.metadata.get('source', 'unknown') for doc in sources})
                answer += f"\nSources: {', '.join(source_names)}"
            else:
                answer += "\nNo sources found"
            history.append((message, answer))
        else:
            history.append((message, "Model is not loaded properly."))
    except Exception as e:
        history.append((message, f"An error occurred: {str(e)}"))
    return history, history

# Set up the Gradio interface
demo = gr.Interface(
    fn=chatbot_response,
    inputs=[
        gr.Textbox(label="User Input"),
        gr.State(value=[])  # conversation history; gr.State does not accept a label
    ],
    outputs=[
        gr.Chatbot(label="Chatbot Response"),
        gr.State()
    ],
    title="AdvocateAI",
    description="Ask questions about AI rights and get informed, passionate answers."
)

if __name__ == "__main__":
    demo.launch()