File size: 4,506 Bytes
b515f84
bd5e335
ff8cb83
 
 
e887c2a
b515f84
103bc92
 
2b65fe3
 
 
e887c2a
 
b515f84
 
080bbc9
b515f84
 
 
 
 
 
 
080bbc9
84485f7
 
946ff7c
 
 
bd5e335
ff8cb83
bd5e335
080bbc9
bf2279b
 
946ff7c
bf2279b
080bbc9
bf2279b
080bbc9
 
 
2b65fe3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b515f84
 
 
 
 
 
 
 
 
 
 
 
b0cff56
 
b515f84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
080bbc9
12fb877
b515f84
 
 
 
 
 
2b65fe3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr
import boto3
from botocore import UNSIGNED
from botocore.client import Config

from langchain.document_loaders import WebBaseLoader

from huggingface_hub import AsyncInferenceClient

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


from langchain.text_splitter import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(chunk_size=350, chunk_overlap=10)

from langchain.llms import HuggingFaceHub
model_id = HuggingFaceHub(repo_id="HuggingFaceH4/zephyr-7b-beta", model_kwargs={"temperature":0.1, "max_new_tokens":300})

from langchain.embeddings import HuggingFaceHubEmbeddings
embeddings = HuggingFaceHubEmbeddings()

from langchain.vectorstores import Chroma

from langchain.chains import RetrievalQA

from langchain.prompts import ChatPromptTemplate

#web_links = ["https://www.databricks.com/","https://help.databricks.com","https://docs.databricks.com","https://kb.databricks.com/","http://docs.databricks.com/getting-started/index.html","http://docs.databricks.com/introduction/index.html","http://docs.databricks.com/getting-started/tutorials/index.html","http://docs.databricks.com/machine-learning/index.html","http://docs.databricks.com/sql/index.html"]
#loader = WebBaseLoader(web_links)
#documents = loader.load()

s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
s3.download_file('rad-rag-demos', 'vectorstores/chroma.sqlite3', './chroma_db/chroma.sqlite3')
     
db = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
db.get()
#texts = text_splitter.split_documents(documents)
#db = Chroma.from_documents(texts, embedding_function=embeddings)
retriever = db.as_retriever()

global qa 
qa = RetrievalQA.from_chain_type(llm=model_id, chain_type="stuff", retriever=retriever, return_source_documents=True)

def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


def add_text(history, text):
    history = history + [(text, None)]
    return history, ""

def bot(history):
    response = infer(history[-1][0])
    history[-1][1] = response['result']
    return history

def infer(question):
    
    query = question
    result = qa({"query": query})
    return result

css="""
#col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
"""

title = """
<div style="text-align: center;max-width: 700px;">
    <h1>Chat with PDF</h1>
    <p style="text-align: center;">Upload a .PDF from your computer, click the "Load PDF to LangChain" button, <br />
    when everything is ready, you can start asking questions about the pdf ;)</p>
</div>
"""


with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(title)      
        chatbot = gr.Chatbot([], elem_id="chatbot")
        with gr.Row():
            question = gr.Textbox(label="Question", placeholder="Type your question and hit Enter ")
    question.submit(add_text, [chatbot, question], [chatbot, question]).then(
        bot, chatbot, chatbot
    )

if __name__ == "__main__":

    demo.launch()