import os 

import gradio as gr
import boto3
from botocore import UNSIGNED
from botocore.client import Config

import torch


from huggingface_hub import AsyncInferenceClient

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms import HuggingFaceHub
from langchain.embeddings import HuggingFaceHubEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.prompts import ChatPromptTemplate
from langchain.document_loaders import WebBaseLoader
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.llms import CTransformers

from transformers import AutoModel



from threading import Thread
from typing import Iterator

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))


# text_splitter = RecursiveCharacterTextSplitter(chunk_size=350, chunk_overlap=10)

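# Embeddings are served remotely via the Hugging Face Hub inference API; they are
# assumed to match the embeddings used when the downloaded Chroma store was built.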
embeddings = HuggingFaceHubEmbeddings()

model_id = "TheBloke/zephyr-7B-beta-GGUF"
# model_id = "HuggingFaceH4/zephyr-7b-beta"
# model_id = "meta-llama/Llama-2-7b-chat-hf"

# model = AutoModelForCausalLM.from_pretrained(
#     model_id, 
#     device_map="auto",
#     low_cpu_mem_usage=True
# )

# print( "initalized model")

# tokenizer = AutoTokenizer.from_pretrained(model_id)
# model = AutoModelForCausalLM.from_pretrained(model_id)
# model = AutoModel.from_pretrained("TheBloke/zephyr-7B-beta-GGUF")

device = "cpu"


# llm_model = CTransformers(
#     model="TheBloke/zephyr-7B-beta-GGUF",
#     model_type="mistral",
#     max_new_tokens=4384,
#     temperature=0.2,
#     repetition_penalty=1.13,
#     device=device  # Set the device explicitly during model initialization
# )

# Load the full HuggingFaceH4/zephyr-7b-beta model directly with transformers
# (AutoTokenizer and AutoModelForCausalLM are already imported above).

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
model = AutoModelForCausalLM.from_pretrained("HuggingFaceH4/zephyr-7b-beta")


# tokenizer = AutoTokenizer.from_pretrained(model_id)
# model = AutoModelForCausalLM.from_pretrained(model_id)

# Build a LangChain-compatible LLM from the already-loaded model so the RetrievalQA
# chain below has something to call (the CTransformers variant above is left
# commented out as an alternative).
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=DEFAULT_MAX_NEW_TOKENS)
llm_model = HuggingFacePipeline(pipeline=pipe)


print("initialized model")

# tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.use_default_system_prompt = False


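# Fetch a prebuilt Chroma vector store (SQLite file) from a public S3 bucket,
# using unsigned/anonymous requests so no AWS credentials are needed.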
s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
os.makedirs("./chroma_db", exist_ok=True)  # ensure the target directory exists before downloading
s3.download_file('rad-rag-demos', 'vectorstores/chroma.sqlite3', './chroma_db/chroma.sqlite3')
     
db = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
db.get()  # read the collection once as a sanity check that the downloaded store loads


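# Expose the vector store as a LangChain retriever for the QA chain below.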
retriever = db.as_retriever()

# Retrieval-augmented QA chain over the Chroma store; "stuff" packs the retrieved chunks into a single prompt.
qa = RetrievalQA.from_chain_type(llm=llm_model, chain_type="stuff", retriever=retriever, return_source_documents=True)
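# Example (hypothetical query string): the chain returns a dict with the answer
# under "result" and the retrieved chunks under "source_documents", e.g.
#   out = qa({"query": "What does the indexed collection say about X?"})
#   print(out["result"])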

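# Stream a chat completion from the locally loaded Zephyr model. The full
# conversation (optional system prompt + history + new message) is rendered with
# the tokenizer's chat template, truncated to MAX_INPUT_TOKEN_LENGTH, and the
# generated text is yielded incrementally as Gradio consumes the generator.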
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

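    # model.generate blocks, so it runs in a background thread while the
    # TextIteratorStreamer yields decoded text pieces to the UI as they arrive.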
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


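# Helpers for the (currently commented-out) custom Blocks UI below: they append
# the user's message to the history and answer it through the RetrievalQA chain.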
def add_text(history, text):
    history = history + [(text, None)]
    return history, ""

def bot(history):
    response = infer(history[-1][0])
    history[-1][1] = response['result']
    return history

def infer(question):
    
    query = question
    result = qa({"query": query})
    return result

css="""
#col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
"""

title = """
<div style="text-align: center;max-width: 700px;">
    <h1>Chat with the Document Collection</h1>
    <p style="text-align: center;">Ask questions about the preloaded document collection; <br />
    answers are retrieved from a prebuilt Chroma vector store and generated by Zephyr-7B.</p>
</div>
"""


# with gr.Blocks(css=css) as demo:
#     with gr.Column(elem_id="col-container"):
#         gr.HTML(title)      
#         chatbot = gr.Chatbot([], elem_id="chatbot")
#         with gr.Row():
#             question = gr.Textbox(label="Question", placeholder="Type your question and hit Enter ")
#     question.submit(add_text, [chatbot, question], [chatbot, question]).then(
#         bot, chatbot, chatbot
#     )

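# Chat UI: Gradio calls generate() for each message and exposes the sampling
# parameters (max new tokens, temperature, top-p, top-k, repetition penalty) as sliders.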
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
)

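# Wrap the chat interface in a Blocks layout; the description, duplicate button,
# and license markdown are left disabled.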
with gr.Blocks(css="style.css") as demo:
    # gr.Markdown(DESCRIPTION)
    # gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()
    # gr.Markdown(LICENSE)
#x = 0

if __name__ == "__main__":

    demo.launch()