import gradio as gr
import requests
from pdfminer.high_level import extract_text
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
# from langchain_core.runnables import Runnable  # only needed by the commented-out HF client wrapper below
from io import BytesIO
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain.text_splitter import CharacterTextSplitter
# from huggingface_hub import InferenceClient
# import logging
import os

# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)

# Read the OpenAI key from the environment; never hard-code a real key as the fallback.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")


# client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

# class HuggingFaceInferenceClientRunnable(Runnable):
#     def __init__(self, client, max_tokens=512, temperature=0.7, top_p=0.95):
#         self.client = client
#         self.max_tokens = max_tokens
#         self.temperature = temperature
#         self.top_p = top_p

#     def invoke(self, input, config=None):
#         prompt = input.to_messages()[0].content
#         messages = [{"role": "user", "content": prompt}]

#         response = ""
#         for part in self.client.chat_completion(
#             messages,
#             max_tokens=self.max_tokens,
#             stream=True,
#             temperature=self.temperature,
#             top_p=self.top_p
#         ):
#             token = part.choices[0].delta.content
#             if token:
#                 response += token

#         return response

#     def update_params(self, max_tokens, temperature, top_p):
#         self.max_tokens = max_tokens
#         self.temperature = temperature
#         self.top_p = top_p


def extract_pdf_text(url: str) -> str:
    """Download a PDF and return its extracted plain text."""
    response = requests.get(url, timeout=60)
    response.raise_for_status()  # fail fast on HTTP errors instead of parsing an error page
    pdf_file = BytesIO(response.content)
    return extract_text(pdf_file)


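# Ingest: fetch the source PDF once at startup and wrap the extracted text in a
# single LangChain Document.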
pdf_url = "https://arxiv.org/pdf/2408.09869"
text = extract_pdf_text(pdf_url)
docs_list = [Document(page_content=text)]

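# Chunk by tiktoken token count. Note: 7,500-token chunks are far larger than the
# ~256-token input window of all-MiniLM-L6-v2, so the embedder effectively sees
# only the start of each chunk; smaller chunks would likely retrieve better.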
text_splitter = CharacterTextSplitter.from_tiktoken_encoder(chunk_size=7500, chunk_overlap=100)
docs_splits = text_splitter.split_documents(docs_list)

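# Embed the chunks locally and index them in an in-memory Chroma collection;
# the retriever returns the chunks most similar to a query string.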
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vectorstore = Chroma.from_documents(
    documents=docs_splits,
    collection_name="rag-chroma",
    embedding=embeddings,
)
retriever = vectorstore.as_retriever()

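# gpt-3.5-turbo replaces the commented-out Zephyr client above. The sampling
# parameters set here are defaults; process_query() overwrites them per request.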
llm = ChatOpenAI(
    model="gpt-3.5-turbo",
    api_key=OPENAI_API_KEY,
    max_tokens=512,
    temperature=0.7,
    top_p=0.95
)

# Baseline chain: ask the LLM directly, with no retrieved context
before_rag_template = "What is {topic}?"
before_rag_prompt = ChatPromptTemplate.from_template(before_rag_template)
before_rag_chain = before_rag_prompt | llm | StrOutputParser()

# RAG chain: answer the same kind of question, grounded in the retrieved chunks
after_rag_template = """You are a {role}. Summarize the following content for yourself, writing in the first person.
Only include content relevant to that role, as in a resume summary.

Context:
{context}

Question: Give a one-paragraph summary of the key skills a {role} can draw from this document.
"""
after_rag_prompt = ChatPromptTemplate.from_template(after_rag_template)

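# The retriever expects a query string, but the chain is invoked with a dict,
# so rebuild the retrieval query from the selected role.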
def format_query(input_dict):
    return f"Give a one-paragraph summary of the key skills a {input_dict['role']} can draw from this document."

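# Collapse the retrieved Documents into one string for the {context} slot;
# without this, the prompt would receive the repr of a list of Document objects.
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)
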
after_rag_chain = (
    {
        "context": format_query | retriever,
        "role": lambda x: x["role"],
    }
    | after_rag_prompt
    | llm
    | StrOutputParser()
)

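# Submit handler: apply the slider values, then run the baseline chain and the
# RAG chain so the two outputs can be compared side by side.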
def process_query(role, system_message, max_tokens, temperature, top_p):
    # NOTE: system_message comes from the UI but is not yet used by either chain.
    # Apply the slider values by mutating the shared ChatOpenAI instance; this
    # affects both chains (and any concurrent users of this process).
    llm.max_tokens = max_tokens
    llm.temperature = temperature
    llm.top_p = top_p

    # Before RAG
    before_rag_result = before_rag_chain.invoke({"topic": "Hugging Face"})

    # After RAG
    after_rag_result = after_rag_chain.invoke({"role": role})

    return f"**Before RAG**\n{before_rag_result}\n\n**After RAG**\n{after_rag_result}"


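# Gradio UI: a role picker plus generation controls, with both chain outputs
# rendered into a single textbox.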
with gr.Blocks() as demo:
    gr.Markdown("## Zephyr Chatbot Controls")

    role_dropdown = gr.Dropdown(choices=["SDE", "BA"], label="Select Role", value="SDE")

    system_message = gr.Textbox(value="You are a friendly chatbot.", label="System message")
    max_tokens = gr.Slider(1, 2048, value=512, step=1, label="Max tokens")
    temperature = gr.Slider(0.1, 2.0, value=0.7, label="Temperature", step=0.1)  # OpenAI caps temperature at 2.0
    top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top-p", step=0.05)

    output = gr.Textbox(label="Output", lines=20)

    submit_btn = gr.Button("Submit")
    clear_btn = gr.Button("Clear")

    submit_btn.click(
        fn=process_query,
        inputs=[role_dropdown, system_message, max_tokens, temperature, top_p],
        outputs=output
    )

    def clear_output():
        gr.Info("Chat cleared!")  # toast notification; returning "" clears the textbox
        return ""

    clear_btn.click(
        fn=clear_output,
        outputs=[output]
    )

if __name__ == "__main__":
    demo.launch()