File size: 6,674 Bytes
258cebf
df527c8
d594a38
258cebf
d594a38
 
 
258cebf
d594a38
 
258cebf
 
9847233
258cebf
d594a38
 
 
df527c8
d594a38
 
 
258cebf
df527c8
258cebf
 
 
 
 
 
 
9847233
542890e
9847233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d594a38
 
df527c8
 
9847233
 
df527c8
 
 
 
 
 
 
 
 
 
 
 
9847233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df527c8
d594a38
 
9847233
d594a38
 
258cebf
 
 
 
 
 
 
d594a38
df527c8
d594a38
 
258cebf
df527c8
 
bda01ad
 
 
d594a38
bda01ad
 
df527c8
9847233
258cebf
bda01ad
d594a38
 
bda01ad
df527c8
bda01ad
9847233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d594a38
258cebf
 
 
 
 
 
9847233
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import datetime
import os

import dotenv
import streamlit as st

from langchain_core.tracers.langchain import LangChainTracer
from langchain.callbacks.base import BaseCallbackHandler
from langsmith.client import Client

import web_rag as wr
import web_crawler as wc
import copywriter as cw

# Load API keys and configuration (provider keys, LANGSMITH_PROJECT_NAME, ...)
# from a local .env file into the process environment.
dotenv.load_dotenv()

# LangSmith tracer attached to LLM calls below so every chain run is logged
# to the configured LangSmith project. Client() reads its own credentials
# from the environment populated above.
ls_tracer = LangChainTracer(
    project_name=os.getenv("LANGSMITH_PROJECT_NAME"),
    client=Client()
)

class StreamHandler(BaseCallbackHandler):
    """Callback that live-renders streamed LLM tokens in a Streamlit container."""

    def __init__(self, container, initial_text=""):
        """Remember the target container and seed the accumulated text."""
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs):
        """Append *token* to the buffer and refresh the container with it."""
        self.text = self.text + token
        self.container.markdown(self.text)
        

def create_links_markdown(sources_list):
    """
    Render a list of search-result sources as a markdown bullet list.

    Args:
        sources_list (list): Dictionaries describing the sources; each must
            provide 'title', 'link', and 'snippet' keys.

    Returns:
        str: One markdown bullet per source, with the title hyperlinked to
             the URL and the snippet on an indented line beneath it.
    """
    bullets = [
        f"- [{source['title']}]({source['link']})\n  {source['snippet']}"
        for source in sources_list
    ]
    return "\n".join(bullets)

st.set_page_config(layout="wide")
st.title("πŸ” Simple Search Agent πŸ’¬")

# Detect which model providers are usable, based on which credential
# environment variables are present. Computed once per session.
if "providers" not in st.session_state:
    _PROVIDER_ENV_VARS = [
        ("FIREWORKS_API_KEY", "fireworks"),
        ("COHERE_API_KEY", "cohere"),
        ("OPENAI_API_KEY", "openai"),
        ("GROQ_API_KEY", "groq"),
        ("OLLAMA_API_KEY", "ollama"),
        ("CREDENTIALS_PROFILE_NAME", "bedrock"),
    ]
    st.session_state["providers"] = [
        provider for env_var, provider in _PROVIDER_ENV_VARS if os.getenv(env_var)
    ]

# Sidebar controls for model choice, generation parameters, and retrieval
# depth. Fixed: typos in user-facing help texts ("retrive", "generate a write").
with st.sidebar.expander("Options", expanded=False):
    model_provider = st.selectbox("Model provider 🧠", st.session_state["providers"])
    temperature = st.slider("Model temperature 🌑️", 0.0, 1.0, 0.1, help="The higher the more creative")
    max_pages = st.slider("Max pages to retrieve πŸ”", 1, 20, 15, help="How many web pages to retrieve from the internet")
    top_k_documents = st.slider("Nbr of doc extracts to consider πŸ“„", 1, 20, 5, help="How many of the top extracts to consider")
    reviewer_mode = st.checkbox("Draft / Comment / Rewrite mode ✍️", value=False, help="First generate a draft, then comments, then a rewrite")

# Sidebar placeholder later filled with markdown links to the retrieved sources.
with st.sidebar.expander("Links", expanded=False):
    links_md = st.markdown("")

# In reviewer mode, expose placeholders for the intermediate artifacts of the
# draft -> comments -> rewrite pipeline. Fixed: caption typo "Comparaison".
if reviewer_mode:
    with st.sidebar.expander("Answer review", expanded=False):
        st.caption("Draft")
        draft_md = st.markdown("")
        st.divider()
        st.caption("Comments")
        comments_md = st.markdown("")
        st.divider()
        st.caption("Comparison")
        # Variable name intentionally unchanged: it is referenced by the
        # comparison step after the chat response.
        comparaison_md = st.markdown("")

# Seed the conversation with a greeting on first load.
if "messages" not in st.session_state:
    st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]

# Replay the conversation history; assistant messages that carry a
# 'message_id' also get a download button for their content.
for message in st.session_state.messages:
    role, content = message["role"], message["content"]
    st.chat_message(role).write(content)
    if role == "assistant" and "message_id" in message:
        st.download_button(
            label="Download",
            data=content,
            file_name=f"{message['message_id']}.txt",
            mime="text/plain"
        )

# Main turn handler: search the web, build a RAG prompt, and stream the answer.
# Fixed: placeholder typo ("Enter you"), and the assistant message is now
# stored WITH its message_id so the history-replay loop's download button
# (which checks for a 'message_id' key) appears for past answers too.
if prompt := st.chat_input("Enter your instructions..."):
    st.chat_message("user").write(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    chat, embedding_model = wr.get_models(model_provider, temperature=temperature)

    with st.status("Thinking", expanded=True):
        st.write("I first need to do some research")

        # Rewrite the user's question into a more effective web-search query.
        optimize_search_query = wr.optimize_search_query(chat, query=prompt, callbacks=[ls_tracer])
        st.write(f"I should search the web for: {optimize_search_query}")

        # Fetch search results and surface the source links in the sidebar.
        sources = wc.get_sources(optimize_search_query, max_pages=max_pages)
        links_md.markdown(create_links_markdown(sources))

        st.write(f"I'll now retrieve the {len(sources)} webpages and documents I found")
        contents = wc.get_links_contents(sources)

        st.write(f"Reading through the {len(contents)} sources I managed to retrieve")
        vector_store = wc.vectorize(contents, embedding_model=embedding_model)
        st.write(f"I collected {vector_store.index.ntotal} chunks of data and I can now answer")

        if reviewer_mode:
            # Draft -> review comments -> final rewrite pipeline.
            st.write("Creating a draft")
            draft_prompt = wr.build_rag_prompt(
                chat, prompt, optimize_search_query,
                vector_store, top_k=top_k_documents, callbacks=[ls_tracer])
            draft = chat.invoke(draft_prompt, stream=False, config={"callbacks": [ls_tracer]})
            draft_md.markdown(draft.content)
            st.write("Sending draft for review")
            comments = cw.generate_comments(chat, prompt, draft, callbacks=[ls_tracer])
            comments_md.markdown(comments)
            st.write("Reviewing comments and generating final answer")
            rag_prompt = cw.get_final_text_prompt(prompt, draft, comments)
        else:
            rag_prompt = wr.build_rag_prompt(
                chat, prompt, optimize_search_query, vector_store,
                top_k=top_k_documents, callbacks=[ls_tracer]
            )

    with st.chat_message("assistant"):
        # Stream tokens into the chat bubble while also tracing to LangSmith.
        st_cb = StreamHandler(st.empty())
        result = chat.invoke(rag_prompt, stream=True, config={"callbacks": [st_cb, ls_tracer]})
        response = result.content.strip()
        # NOTE(review): message_id embeds the raw prompt plus a timestamp and
        # is used as a download file name — prompts containing path-hostile
        # characters may yield awkward file names; consider sanitizing.
        message_id = f"{prompt}{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
        st.session_state.messages.append(
            {"role": "assistant", "content": response, "message_id": message_id}
        )

    # Offer an immediate download of the answer just produced.
    if st.session_state.messages[-1]["role"] == "assistant":
        st.download_button(
            label="Download",
            data=st.session_state.messages[-1]["content"],
            file_name=f"{message_id}.txt",
            mime="text/plain"
        )

    if reviewer_mode:
        # Ask the model to compare the draft against the final answer and
        # show the result in the review sidebar.
        compare_prompt = cw.get_compare_texts_prompts(prompt, draft_text=draft, final_text=response)
        result = chat.invoke(compare_prompt, stream=False, config={"callbacks": [ls_tracer]})
        comparaison_md.markdown(result.content)