File size: 3,199 Bytes
258cebf
d594a38
258cebf
d594a38
 
 
258cebf
d594a38
 
258cebf
 
 
d594a38
 
 
 
 
 
 
258cebf
 
 
 
 
 
 
 
542890e
d594a38
 
 
 
 
 
258cebf
 
 
d594a38
 
 
258cebf
 
 
 
 
 
 
d594a38
258cebf
 
 
d594a38
 
 
258cebf
bda01ad
 
 
d594a38
bda01ad
 
542890e
258cebf
bda01ad
d594a38
 
bda01ad
d594a38
bda01ad
d594a38
258cebf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d594a38
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import datetime

import dotenv
import streamlit as st

from langchain_core.tracers.langchain import LangChainTracer
from langchain.callbacks.base import BaseCallbackHandler
from langsmith.client import Client

import web_rag as wr
import web_crawler as wc

# Load variables from a local .env file into the process environment before
# any client is constructed.
# NOTE(review): presumably Client() reads its LangSmith credentials from the
# environment, which is why this call comes first -- confirm.
dotenv.load_dotenv()

# Tracer shared by every LLM/chain call in this script; runs are reported
# under the "Search Agent UI" project.
ls_tracer = LangChainTracer(client=Client(), project_name="Search Agent UI")

class StreamHandler(BaseCallbackHandler):
    """Callback handler that mirrors streamed LLM tokens into a UI container.

    Every token received is appended to the text accumulated so far, and the
    whole text is then re-rendered in ``container`` as Markdown.
    """

    def __init__(self, container, initial_text: str = "") -> None:
        """Store the render target and the starting text.

        Args:
            container: Object exposing a ``markdown(text)`` method; it is
                called with the full accumulated text on every new token.
            initial_text: Text the accumulated output starts from.
        """
        self.text = initial_text
        self.container = container

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Append ``token`` to the running text and redraw the container.

        Args:
            token: Newly generated piece of text.
            **kwargs: Extra callback arguments; accepted and ignored.
        """
        accumulated = self.text + token
        self.text = accumulated
        self.container.markdown(accumulated)

# Chat model used by the rest of the script, obtained from the project's
# web_rag helper with the "cohere" provider.
chat = wr.get_chat_llm(provider="cohere")

# The title previously contained mojibake ("πŸ”" / "πŸ’¬"): the UTF-8 bytes of
# the two emoji had been decoded with a single-byte codec. The intended
# characters (magnifying glass and speech balloon) are restored here.
st.title("🔍 Simple Search Agent 💬")

# Seed the per-session state the first time the script runs for a session.
if "messages" not in st.session_state:
    # The conversation starts with a greeting from the assistant.
    st.session_state["messages"] = [
        {"role": "assistant", "content": "How can I help you?"}
    ]

if "input_disabled" not in st.session_state:
    # Controls the `disabled` flag of the chat input.
    st.session_state["input_disabled"] = False

# Replay the stored conversation so it is visible on every script run.
for message in st.session_state.messages:
    st.chat_message(message["role"]).write(message["content"])
    # Only assistant entries that carry a "message_id" key get a download
    # button; the id doubles as the name of the downloaded text file.
    if message["role"] == "assistant" and "message_id" in message:
        st.download_button(
            label="Download",
            data=message["content"],
            file_name=f"{message['message_id']}.txt",
            mime="text/plain",
        )

# Handle a newly submitted instruction: research the web, build a RAG prompt
# from what was found, then stream the model's answer into the chat.
if prompt := st.chat_input(
    "Enter your instructions...",
    disabled=st.session_state["input_disabled"],
):
    # Mark the input as busy while this request is being processed.
    st.session_state["input_disabled"] = True

    try:
        st.chat_message("user").write(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Progress panel: each step below reports what it is doing.
        with st.status("Thinking", expanded=True):
            st.write("I first need to do some research")

            optimize_search_query = wr.optimize_search_query(
                chat, query=prompt, callbacks=[ls_tracer]
            )
            st.write(f"I should search the web for: {optimize_search_query}")

            sources = wc.get_sources(optimize_search_query, max_pages=20)

            st.write(f"I'll now retrieve the {len(sources)} webpages and documents I found")
            contents = wc.get_links_contents(sources)

            st.write(f"Reading through the {len(contents)} sources I managed to retrieve")
            vector_store = wc.vectorize(contents)
            st.write(f"I collected {vector_store.index.ntotal} chunk of data and I can now answer")

        rag_prompt = wr.build_rag_prompt(
            prompt, optimize_search_query, vector_store, top_k=5, callbacks=[ls_tracer]
        )
        with st.chat_message("assistant"):
            # Tokens are rendered into this placeholder as they arrive.
            st_cb = StreamHandler(st.empty())
            result = chat.invoke(
                rag_prompt, stream=True, config={"callbacks": [st_cb, ls_tracer]}
            )
            response = result.content.strip()

            # NOTE(review): the raw prompt is embedded in the download file
            # name; characters such as "/" or ":" may need sanitizing.
            timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            message_id = f"{prompt}{timestamp}"

            # Store the id with the message: the history loop only shows a
            # Download button for assistant entries that have "message_id",
            # so without it the button vanished on the next rerun.
            st.session_state.messages.append(
                {"role": "assistant", "content": response, "message_id": message_id}
            )
            st.download_button(
                label="Download",
                data=response,
                file_name=f"{message_id}.txt",
                mime="text/plain",
            )
    finally:
        # Always re-enable the input. Previously an exception in any step
        # left the flag True, so the chat input stayed disabled on every
        # later rerun of the session.
        st.session_state["input_disabled"] = False