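# RAG chatbot: answers questions against a pre-built FAISS vector store using a
# quantized Llama-2-7B GGML model (via ctransformers), served through a Gradio UI.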
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms import CTransformers
from langchain.chains import RetrievalQA
import gradio as gr
from huggingface_hub import hf_hub_download

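# Path to a pre-built FAISS index; it must exist before the app starts.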
DB_FAISS_PATH = "vectorstores/db_faiss"

def load_llm():
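    # Fetch the quantized GGML weights from the Hugging Face Hub on first run;
    # later runs reuse the local cache in ./models.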
    model_name = 'TheBloke/Llama-2-7B-Chat-GGML'
    model_path = hf_hub_download(repo_id=model_name, filename='llama-2-7b-chat.ggmlv3.q8_0.bin', cache_dir='./models')
    llm = CTransformers(
        model=model_path,
        model_type="llama",
        max_new_tokens=512,
        temperature=0.5
    )
    return llm

custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know; don't try to make up an answer.

Context: {context}
Question: {question}

Only return the helpful answer below and nothing else.
Helpful answer:
"""

def set_custom_prompt():
    prompt = PromptTemplate(template=custom_prompt_template, input_variables=['context', 'question'])
    return prompt

def retrieval_QA_chain(llm, prompt, db):
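    # "stuff" inserts the top-k retrieved chunks directly into the prompt's {context}.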
    qachain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={'k': 2}),
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt}
    )
    return qachain

def qa_bot():
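    # Wire the embeddings, FAISS index, LLM, and custom prompt into a RetrievalQA chain.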
    embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2', model_kwargs={'device': 'cpu'})
    db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
    llm = load_llm()
    qa_prompt = set_custom_prompt()
    qa = retrieval_QA_chain(llm, qa_prompt, db)
    return qa

bot = qa_bot()

def chatbot_response(message, history):
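    # Run the chain on the user's message and append a (question, answer) pair to the history.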
    try:
        response = bot.invoke({'query': message})
        answer = response["result"]
        sources = response["source_documents"]
        if sources:
            # List the source paths instead of dumping full Document objects.
            answer += "\nSources: " + ", ".join(str(doc.metadata.get("source", "unknown")) for doc in sources)
        else:
            answer += "\nNo sources found"
        history.append((message, answer))
    except Exception as e:
        history.append((message, f"An error occurred: {str(e)}"))
    # Return the updated history for the Chatbot and an empty string to clear the textbox.
    return history, ""

with gr.Blocks() as demo:
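    # Minimal chat UI: a Chatbot pane plus a textbox and send button in one row.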
    chatbot = gr.Chatbot()
    with gr.Row():
        msg = gr.Textbox(show_label=False, placeholder="Enter your question...")
        submit = gr.Button("Send")
    submit.click(chatbot_response, [msg, chatbot], [chatbot, msg])
    msg.submit(chatbot_response, [msg, chatbot], [chatbot, msg])

demo.launch()