CyranoB committed on
Commit 258cebf · 1 Parent(s): 4fd7d00

Added streaming in the web UI

Files changed (2)
  1. search_agent_ui.py +51 -17
  2. web_rag.py +11 -3
search_agent_ui.py CHANGED
@@ -1,13 +1,15 @@
+import datetime
+
 import dotenv
-
 import streamlit as st
 
-import web_rag as wr
-import web_crawler as wc
-
 from langchain_core.tracers.langchain import LangChainTracer
+from langchain.callbacks.base import BaseCallbackHandler
 from langsmith.client import Client
 
+import web_rag as wr
+import web_crawler as wc
+
 dotenv.load_dotenv()
 
 ls_tracer = LangChainTracer(
@@ -15,6 +17,14 @@ ls_tracer = LangChainTracer(
     client=Client()
 )
 
+class StreamHandler(BaseCallbackHandler):
+    def __init__(self, container, initial_text=""):
+        self.container = container
+        self.text = initial_text
+
+    def on_llm_new_token(self, token: str, **kwargs):
+        self.text += token
+        self.container.markdown(self.text)
 
 chat = wr.get_chat_llm(provider="cohere")
 
@@ -22,40 +32,64 @@ st.title("🔍 Simple Search Agent 💬")
 
 if "messages" not in st.session_state:
     st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
+
+if "input_disabled" not in st.session_state:
+    st.session_state["input_disabled"] = False
 
 for message in st.session_state.messages:
     st.chat_message(message["role"]).write(message["content"])
+    if message["role"] == "assistant" and 'message_id' in message:
+        st.download_button(
+            label="Download",
+            data=message["content"],
+            file_name=f"{message['message_id']}.txt",
+            mime="text/plain"
+        )
 
-if prompt := st.chat_input():
+if prompt := st.chat_input("Enter your instructions...", disabled=st.session_state["input_disabled"]):
+
+    st.session_state["input_disabled"] = True
 
     st.chat_message("user").write(prompt)
     st.session_state.messages.append({"role": "user", "content": prompt})
-
+
     message = "I first need to do some research"
     st.chat_message("assistant").write(message)
     st.session_state.messages.append({"role": "assistant", "content": message})
-
+
     with st.spinner("Optimizing search query"):
         optimize_search_query = wr.optimize_search_query(chat, query=prompt, callbacks=[ls_tracer])
-
+
     message = f"I'll search the web for: {optimize_search_query}"
     st.chat_message("assistant").write(message)
     st.session_state.messages.append({"role": "assistant", "content": message})
-
-
+
     with st.spinner(f"Searching the web for: {optimize_search_query}"):
         sources = wc.get_sources(optimize_search_query, max_pages=20)
-
+
     with st.spinner(f"I'm now retrieving the {len(sources)} webpages and documents I found (be patient)"):
         contents = wc.get_links_contents(sources)
 
-
     with st.spinner(f"Reading through the {len(contents)} sources I managed to retrieve"):
         vector_store = wc.vectorize(contents)
+
+    message = f"Got {vector_store.index.ntotal} chunks of data"
+    st.chat_message("assistant").write(message)
+    st.session_state.messages.append({"role": "assistant", "content": message})
 
-    with st.spinner( "Ok I have now enough information to answer"):
-        response = wr.query_rag(chat, prompt, optimize_search_query, vector_store, callbacks=[ls_tracer])
-
-    st.chat_message("assistant").write(response)
-    st.session_state.messages.append({"role": "assistant", "content": response})
+    rag_prompt = wr.build_rag_prompt(prompt, optimize_search_query, vector_store, top_k=5, callbacks=[ls_tracer])
+    with st.chat_message("assistant"):
+        st_cb = StreamHandler(st.empty())
+        result = chat.invoke(rag_prompt, stream=True, config={"callbacks": [st_cb, ls_tracer]})
+        response = result.content.strip()
+        message_id = f"{prompt}{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
+    st.session_state.messages.append({"role": "assistant", "content": response})
+    if st.session_state.messages[-1]["role"] == "assistant":
+        st.download_button(
+            label="Download",
+            data=st.session_state.messages[-1]["content"],
+            file_name=f"{message_id}.txt",
+            mime="text/plain"
+        )
+    st.session_state["input_disabled"] = False
 
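A note on the mechanics: LangChain calls a callback handler's on_llm_new_token hook once per generated token when the model is invoked with streaming enabled, and this StreamHandler redraws a single st.empty() placeholder with the accumulated text instead of appending new widgets. The effect can be previewed without a model or API key; the harness below is a sketch that is not part of the commit (the fake token list and the file name streaming_demo.py are made up) and simply drives the same handler by hand:

import time

import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler


class StreamHandler(BaseCallbackHandler):
    # Same handler as in the diff: accumulate tokens, redraw one placeholder.
    def __init__(self, container, initial_text=""):
        self.container = container  # e.g. st.empty()
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs):
        self.text += token                  # grow the partial answer
        self.container.markdown(self.text)  # replace in place, don't append


# Feed fake tokens through the handler to watch the placeholder update.
handler = StreamHandler(st.empty())
for token in "Streaming tokens into a single placeholder.".split():
    handler.on_llm_new_token(token + " ")
    time.sleep(0.1)

Run it with streamlit run streaming_demo.py; swapping the loop for chat.invoke(rag_prompt, stream=True, config={"callbacks": [handler]}) reproduces the wiring in the diff above.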
web_rag.py CHANGED
@@ -120,6 +120,9 @@ def get_optimized_search_messages(query):
         Example:
             Question: Write a short article about the solar system in the style of donald trump
             Search query: solar system**
+        Example:
+            Question: Write a short LinkedIn post about how the "freakeconomics" book predictions didn't pan out
+            Search query: freakeconomics book predictions failed**
         """
     )
     human_message = HumanMessage(
@@ -209,9 +212,14 @@ def multi_query_rag(chat_llm, question, search_query, vectorstore, callbacks = []):
     return response.content
 
 
-def query_rag(chat_llm, question, search_query, vectorstore, callbacks = []):
-    unique_docs = vectorstore.similarity_search(search_query, k=15, callbacks=callbacks, verbose=True)
+def build_rag_prompt(question, search_query, vectorstore, top_k=10, callbacks=[]):
+    unique_docs = vectorstore.similarity_search(
+        search_query, k=top_k, callbacks=callbacks, verbose=True)
     context = format_docs(unique_docs)
     prompt = get_rag_prompt_template().format(query=question, context=context)
+    return prompt
+
+def query_rag(chat_llm, question, search_query, vectorstore, callbacks=[]):
+    prompt = build_rag_prompt(question, search_query, vectorstore, callbacks=callbacks)
     response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
-    return response.content
+    return response.content
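The web_rag.py side of the change is what makes the streaming possible: retrieval and prompt assembly move out of query_rag into build_rag_prompt, so a caller can build the prompt first and then own the model invocation, attaching whatever callbacks it needs. A sketch of the two call patterns this leaves callers with; the wrapper names answer_streaming and answer_blocking are illustrative, not from the repo:

import web_rag as wr


def answer_streaming(chat, question, search_query, vector_store, handler):
    # Retrieval and prompt assembly only; no LLM call happens here.
    rag_prompt = wr.build_rag_prompt(question, search_query, vector_store, top_k=5)
    # The caller invokes the model itself, so it can attach a streaming
    # handler (like StreamHandler in search_agent_ui.py) per call.
    result = chat.invoke(rag_prompt, stream=True, config={"callbacks": [handler]})
    return result.content


def answer_blocking(chat, question, search_query, vector_store):
    # The original one-shot behavior stays available for callers
    # that just want the final string.
    return wr.query_rag(chat, question, search_query, vector_store)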