disLodge committed
Commit dcfa5ff · verified · 1 parent: 420226c

Updated the RAG code: replaced the streaming chat handler with a PDF-backed retrieval pipeline (pdfminer.six extraction, token-based chunking, MiniLM embeddings indexed in Chroma) and wired the Gradio UI to show before/after-RAG output.
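For reviewers, the new chains can be smoke-tested outside the Gradio UI roughly as follows. This is a sketch, not part of the commit: it assumes app.py's module-level pipeline (PDF download, chunking, embedding, Chroma indexing) builds successfully at import time, and "Data Scientist" is a placeholder for whatever the UI's role dropdown actually offers.

    # Hypothetical smoke test (not in this commit). Importing app runs the
    # indexing pipeline but not demo.launch(), which is __main__-guarded.
    from app import before_rag_chain, after_rag_chain
    print(before_rag_chain.invoke({"topic": "Hugging Face"}))
    print(after_rag_chain.invoke({"role": "Data Scientist"}))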

Files changed (1):
app.py +93 -29
app.py CHANGED
@@ -1,32 +1,87 @@
 import gradio as gr
 from huggingface_hub import InferenceClient

 client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}] + history
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-    for part in client.chat_completion(
-        messages, max_tokens=max_tokens, stream=True, temperature=temperature,
-        top_p=top_p
-    ):
-        token = part.choices[0].delta.content
-        if token:
-            response += token
-
-    history.append({"role": "user", "content": message})
-    history.append({"role": "assistant", "content": response})
-
-    return history, ""

 with gr.Blocks() as demo:
     gr.Markdown("## Zephyr Chatbot Controls")
@@ -38,12 +93,21 @@ with gr.Blocks() as demo:
     temperature = gr.Slider(0.1, 4.0, value=0.7, label="Temperature", step=0.1)
     top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top-p", step=0.05)

-    with gr.Row():
-        clear_btn = gr.Button("Clear Chat")
-        dummy_btn = gr.Button("Dummy Action")

-    clear_btn.click(lambda: gr.Info("Chat cleared!"))
-    dummy_btn.click(lambda: gr.Info("Dummy action clicked!"))

 if __name__ == "__main__":
     demo.launch()
 
 import gradio as gr
+import requests
+from pdfminer.high_level import extract_text
+from langchain_community.vectorstores import Chroma
+from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint, ChatHuggingFace
+from io import BytesIO
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.documents import Document
+from langchain_core.prompts import ChatPromptTemplate
+from langchain.text_splitter import CharacterTextSplitter
 from huggingface_hub import InferenceClient

 client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

+def extract_pdf_text(url: str) -> str:
+    # Download the PDF and extract its raw text with pdfminer.
+    response = requests.get(url)
+    response.raise_for_status()
+    pdf_file = BytesIO(response.content)
+    text = extract_text(pdf_file)
+    return text
+
+pdf_url = "https://huggingface.co/spaces/disLodge/Call_model/raw/main/temp.pdf"
+text = extract_pdf_text(pdf_url)
+docs_list = [Document(page_content=text)]
+
+# Chunk the document on token counts and index the chunks in Chroma.
+text_splitter = CharacterTextSplitter.from_tiktoken_encoder(chunk_size=7500, chunk_overlap=100)
+docs_splits = text_splitter.split_documents(docs_list)
+
+embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+vectorstore = Chroma.from_documents(
+    documents=docs_splits,
+    collection_name="rag-chroma",
+    embedding=embeddings,
+)
+retriever = vectorstore.as_retriever()
+
+# Wrap the hosted Zephyr model as a LangChain chat model.
+llm = ChatHuggingFace(
+    llm=HuggingFaceEndpoint(repo_id="HuggingFaceH4/zephyr-7b-beta", task="text-generation"),
+)
+
+# Before RAG chain: query the model directly, with no retrieved context.
+before_rag_template = "What is {topic}?"
+before_rag_prompt = ChatPromptTemplate.from_template(before_rag_template)
+before_rag_chain = before_rag_prompt | llm | StrOutputParser()
+
+# After RAG chain: answer from context retrieved out of the PDF.
+after_rag_template = """You are a {role}. Summarize the following content for yourself and speak in the first person.
+Only include content relevant to that role, like a resume summary.
+
+Context:
+{context}
+
+Question: Give a one-paragraph summary of the key skills a {role} can have from this document.
+"""
+after_rag_prompt = ChatPromptTemplate.from_template(after_rag_template)
+
+def format_query(input_dict):
+    # Turn the selected role into the retrieval query string.
+    return f"Give a one-paragraph summary of the key skills a {input_dict['role']} can have from this document."
+
+after_rag_inputs = {
+    # Query the retriever with a role-specific question; pass the role through.
+    "context": format_query | retriever,
+    "role": lambda x: x["role"],
+}
+after_rag_chain = after_rag_inputs | after_rag_prompt | llm | StrOutputParser()
+
+def process_query(role, system_message, max_tokens, temperature, top_p):
+    # Bind the per-request sampling parameters onto the chat model; mutating
+    # the raw InferenceClient has no effect on the LangChain chains.
+    # (system_message is accepted from the UI but not used by these chains.)
+    tuned_llm = llm.bind(max_tokens=max_tokens, temperature=temperature, top_p=top_p)
+
+    # Before RAG
+    before_rag_result = (before_rag_prompt | tuned_llm | StrOutputParser()).invoke(
+        {"topic": "Hugging Face"}
+    )
+
+    # After RAG
+    after_rag_result = (
+        after_rag_inputs | after_rag_prompt | tuned_llm | StrOutputParser()
+    ).invoke({"role": role})
+
+    return f"**Before RAG**\n{before_rag_result}\n\n**After RAG**\n{after_rag_result}"
+

 with gr.Blocks() as demo:
     gr.Markdown("## Zephyr Chatbot Controls")
 
     temperature = gr.Slider(0.1, 4.0, value=0.7, label="Temperature", step=0.1)
     top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top-p", step=0.05)

+    output = gr.Textbox(label="Output", lines=20)
+
+    submit_btn = gr.Button("Submit")
+    clear_btn = gr.Button("Clear")
+
+    submit_btn.click(
+        fn=process_query,
+        inputs=[role_dropdown, system_message, max_tokens, temperature, top_p],
+        outputs=output,
+    )

+    def clear_output():
+        # Show a toast notification and empty the output box.
+        gr.Info("Chat cleared!")
+        return ""
+
+    clear_btn.click(fn=clear_output, outputs=[output])

 if __name__ == "__main__":
     demo.launch()
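Note on dependencies: only app.py changes in this commit, so the packages implied by the new imports are listed here as a sketch (names inferred from the import block; nothing is pinned by the commit itself):

    # requirements.txt sketch; package names inferred from app.py's imports
    gradio
    requests
    pdfminer.six
    langchain
    langchain-community
    langchain-core
    langchain-huggingface
    chromadb
    tiktoken
    sentence-transformers
    huggingface_hub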