CyranoB committed on
Commit
8d1e83e
·
1 Parent(s): a63f98f

Trying multi query retrieve

Files changed (3)
  1. messages.py +114 -85
  2. requirements.txt +1 -0
  3. search_agent.py +77 -87
messages.py CHANGED
@@ -1,92 +1,121 @@
- import json
- from langchain.schema import SystemMessage, HumanMessage
-
- def get_optimized_search_messages(query):
-     messages = [
-         SystemMessage(
-             content="""
-                 You are a serach query optimizer specialist.
-                 Provide a better search query for web search engine to answer the given question, end the queries with ’**’
-                 Tips:
-                     Identify the key concepts in the question
-                     Remove filler words like "how to", "what is", "I want to"
-                     Removed style such as "in the style of", "engaging", "short", "long"
-                     Remove lenght instruction (example: essay, article, letter, blog, post, blogpost, etc)
-                     Keep it short, around 3-7 words total
-                     Put the most important keywords first
-                     Remove formatting instructions
-                     Remove style instructions (exmaple: in the style of, engaging, short, long)
-                     Remove lenght instruction (example: essay, article, letter, etc)
-                 Example:
-                     Question: How do I bake chocolate chip cookies from scratch?
-                     Search query: chocolate chip cookies recipe from scratch**
-                 Example:
-                     Question: I would like you to show me a time line of Marie Curie life. Show results as a markdown table
-                     Search query: Marie Curie timeline**
-                 Example:
-                     Question: I would like you to write a long article on nato vs russia. Use know geopolical frameworks.
-                     Search query: geopolitics nato russia**
-                 Example:
-                     Question: Write a engaging linkedin post about Andrew Ng
-                     Search query: Andrew Ng**
-                 Example:
-                     Question: Write a short artible about the solar system in the style of Carl Sagan
-                     Search query: solar system**
-                 Example:
-                     Question: Should I use Kubernetes? Answer in the style of Gilfoyde from the TV show Silicon Valley
-                     Search query: Kubernetes decision**
-                 Example:
-                     Question: biography of napoleon. include a table with the major events.
-                     Search query: napoleon biography events**
-             """
-         ),
-         HumanMessage(
-             content=f"""
-                 Provide a better search query for web search engine to answer the given question, provide only one search query and nothing else, end the queries with ’**’.
-                 Question: {query}
-                 Search query:
-             """
-         ),
-     ]
-     return messages
 
- def get_query_with_sources_messages(query, relevant_docs):
-     messages = [
-         SystemMessage(
-             content="""
-                 You are an expert research assistant.
-                 You are provided with a Context in JSON format and a Question.
-
-                 Use RAG to answer the Question, providing references and links to the Context material you retrieve and use in your answer:
-                 When generating your answer, follow these steps:
-                 - Retrieve the most relevant context material from your knowledge base to help answer the question
-                 - Cite the references you use by including the title, author, publication, and a link to each source
-                 - Synthesize the retrieved information into a clear, informative answer to the question
-                 - Format your answer in Markdown, using heading levels 2-3 as needed
-                 - Include a "References" section at the end with the full citations and link for each source you used
-
-                 Example of Context JSON entry:
-                 {
-                     "page_content": "This provides access to material related to ...",
-                     "metadata": {
-                         "title": "Introduction - Marie Curie: Topics in Chronicling America",
-                         "link": "https://guides.loc.gov/chronicling-america-marie-curie"
-                     }
-                 }
  """
-         ),
-         HumanMessage(
-             content= f"""
-                 Context information is below.
-                 Context:
-                 ---------------------
-                 {json.dumps(relevant_docs, indent=2, ensure_ascii=False)}
-                 ---------------------
-                 Question: {query}
-                 Answer:
  """
-         ),
-     ]
-     return messages
+ """
+ This module provides functions for generating optimized search messages, RAG prompt templates,
+ and messages for queries with relevant source documents using the LangChain library.
+ """
+
+ from langchain.schema import SystemMessage, HumanMessage
+ from langchain.prompts.chat import (
+     HumanMessagePromptTemplate,
+     SystemMessagePromptTemplate,
+     ChatPromptTemplate
+ )
+ from langchain.prompts.prompt import PromptTemplate
 
+ def get_optimized_search_messages(query):
+     """
+     Generate optimized search messages for a given query.
+
+     Args:
+         query (str): The user's query.
+
+     Returns:
+         list: A list containing the system message and human message for optimized search.
+     """
+     system_message = SystemMessage(
+         content="""
+             I want you to act as a prompt optimizer for web search. I will provide you with a chat prompt, and your goal is to optimize it into a search string that will yield the most relevant and useful information from a search engine like Google.
+             To optimize the prompt:
+                 Identify the key information being requested
+                 Arrange the keywords into a concise search string
+                 Keep it short, around 1 to 5 words total
+                 Put the most important keywords first
+
+             Some tips and things to be sure to remove:
+             - Remove any conversational or instructional phrases
+             - Remove style such as "in the style of", "engaging", "short", "long"
+             - Remove length instruction (example: essay, article, letter, blog, post, blogpost, etc)
+             - Remove style instructions (example: "in the style of", engaging, short, long)
+             - Remove length instruction (example: essay, article, letter, etc)
+
+             Add "**" to the end of the search string to indicate the end of the query
+             Provide your output in this format: optimized search string**
+
+             Example:
+                 Question: How do I bake chocolate chip cookies from scratch?
+                 Search query: chocolate chip cookies recipe from scratch**
+             Example:
+                 Question: I would like you to show me a timeline of Marie Curie's life. Show results as a markdown table
+                 Search query: Marie Curie timeline**
+             Example:
+                 Question: I would like you to write a long article on NATO vs Russia. Use known geopolitical frameworks.
+                 Search query: geopolitics nato russia**
+             Example:
+                 Question: Write an engaging LinkedIn post about Andrew Ng
+                 Search query: Andrew Ng**
+             Example:
+                 Question: Write a short article about the solar system in the style of Carl Sagan
+                 Search query: solar system**
+             Example:
+                 Question: Should I use Kubernetes? Answer in the style of Gilfoyle from the TV show Silicon Valley
+                 Search query: Kubernetes decision**
+             Example:
+                 Question: Biography of Napoleon. Include a table with the major events.
+                 Search query: napoleon biography events**
+             Example:
+                 Question: Write a short article on the history of the United States. Include a table with the major events.
+                 Search query: united states history events**
+             Example:
+                 Question: Write a short article about the solar system in the style of donald trump
+                 Search query: solar system**
+         """
+     )
+     human_message = HumanMessage(
+         content=f"""
+             Question: {query}
+             Search query:
+         """
+     )
+     return [system_message, human_message]
 
+ def get_rag_prompt_template():
  """
+     Get the prompt template for Retrieval-Augmented Generation (RAG).
+
+     Returns:
+         ChatPromptTemplate: The prompt template for RAG.
  """
+     system_prompt = SystemMessagePromptTemplate(
+         prompt=PromptTemplate(
+             input_variables=[],
+             template="""
+                 You are an expert research assistant.
+                 You are provided with a Context in JSON format and a Question.
+                 Each JSON entry contains: content, title, link
+
+                 Use RAG to answer the Question, providing references and links to the Context material you retrieve and use in your answer:
+                 When generating your answer, follow these steps:
+                 - Retrieve the most relevant context material from your knowledge base to help answer the question
+                 - Cite the references you use by including the title, author, publication, and a link to each source
+                 - Synthesize the retrieved information into a clear, informative answer to the question
+                 - Format your answer in Markdown, using heading levels 2-3 as needed
+                 - Include a "References" section at the end with the full citations and link for each source you used
+             """
+         )
+     )
+     human_prompt = HumanMessagePromptTemplate(
+         prompt=PromptTemplate(
+             input_variables=["context", "query"],
+             template="""
+                 Context:
+                 ---------------------
+                 {context}
+                 ---------------------
+                 Question: {query}
+                 Answer:
+             """
+         )
+     )
+     return ChatPromptTemplate(
+         input_variables=["context", "query"],
+         messages=[system_prompt, human_prompt],
+     )
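For orientation, here is a minimal usage sketch (not part of the commit) of how the two new helpers are meant to be wired together, assuming an OpenAI chat model; the question text, context string, and variable names are illustrative only:

# Illustrative usage sketch for the new messages.py helpers (not part of the commit).
from langchain_openai import ChatOpenAI
from messages import get_optimized_search_messages, get_rag_prompt_template

chat = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.0)

# 1. Ask the LLM to rewrite the user's question as a short search string ending in "**".
question = "Write a short article about the solar system in the style of Carl Sagan"
response = chat.invoke(get_optimized_search_messages(question))
search_query = response.content.strip('"').split("**", 1)[0]

# 2. Render the RAG prompt with retrieved context (a JSON string of content/title/link entries).
context = '[{"content": "The solar system has eight planets...", "title": "Example", "link": "https://example.com"}]'
prompt = get_rag_prompt_template().format(query=question, context=context)
answer = chat.invoke(prompt)
print(answer.content)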
requirements.txt CHANGED
@@ -4,6 +4,7 @@ docopt
  faiss-cpu
  python-dotenv
  langchain
+ langchain_core
  langchain_community
  langchain_openai
  langchain_groq
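The added langchain_core dependency backs the new Document import in search_agent.py; a quick sanity check that the package resolves (sketch, not from the commit):

# Sanity-check sketch for the new langchain_core dependency (illustrative only).
from langchain_core.documents.base import Document

doc = Document(page_content="example text", metadata={"title": "Example", "source": "https://example.com"})
print(doc.metadata["source"])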
search_agent.py CHANGED
@@ -32,28 +32,31 @@ from bs4 import BeautifulSoup
  from docopt import docopt
  import dotenv
 
  from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain.schema import SystemMessage, HumanMessage
  from langchain.callbacks import LangChainTracer
  from langchain_groq import ChatGroq
  from langchain_openai import ChatOpenAI
- from langchain_community.chat_models import ChatOllama
  from langchain_openai import OpenAIEmbeddings
- from langchain_community.vectorstores.faiss import FAISS
  from langchain_community.chat_models.bedrock import BedrockChat
  from langsmith import Client
 
  import requests
 
  from rich.console import Console
- from rich.rule import Rule
  from rich.markdown import Markdown
 
 
- def get_chat_llm(provider, model, temperature=0.0):
      match provider:
          case 'bedrock':
-             if(model == None):
                  model = "anthropic.claude-3-sonnet-20240229-v1:0"
              chat_llm = BedrockChat(
                  credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME'),
@@ -61,29 +64,28 @@ def get_chat_llm(provider, model, temperature=0.0):
                  model_kwargs={"temperature": temperature },
              )
          case 'openai':
-             if(model == None):
                  model = "gpt-3.5-turbo"
              chat_llm = ChatOpenAI(model_name=model, temperature=temperature)
          case 'groq':
-             if(model == None):
                  model = 'mixtral-8x7b-32768'
              chat_llm = ChatGroq(model_name=model, temperature=temperature)
          case 'ollama':
-             if(model == None):
-                 model = 'llam2'
              chat_llm = ChatOllama(model=model, temperature=temperature)
          case _:
              raise ValueError(f"Unknown LLM provider {provider}")
-
-     console.log(f"Using {model} on {provider} with temperature {temperature}")
      return chat_llm
 
- def optimize_search_query(query):
-     from messages import get_optimized_search_messages
      messages = get_optimized_search_messages(query)
-     response = chat.invoke(messages, config={"callbacks": callbacks})
      optimized_search_query = response.content
-     return optimized_search_query.strip('"').strip("**")
 
 
  def get_sources(query, max_pages=10, domain=None):
@@ -99,10 +101,10 @@ def get_sources(query, max_pages=10, domain=None):
      }
 
      try:
-         response = requests.get(url, headers=headers)
 
          if response.status_code != 200:
-             raise Exception(f"HTTP error! status: {response.status_code}")
 
          json_response = response.json()
 
@@ -140,8 +142,7 @@ def extract_main_content(html):
              element.extract()
          main_content = ' '.join(soup.body.get_text().split())
          return main_content
-     except Exception as error:
-         #console.log(f"Error extracting main content: {error}")
          return None
 
  def process_source(source):
@@ -159,68 +160,57 @@ def get_links_contents(sources):
      # Filter out None results
      return [result for result in results if result is not None]
 
- def process_and_vectorize_content(
-     contents,
-     query,
-     text_chunk_size=1000,
-     text_chunk_overlap=200,
-     number_of_similarity_results=5
- ):
-     """
-     Process and vectorize content using Langchain.
-
-     Args:
-         contents (list): List of dictionaries containing 'title', 'link', and 'html' keys.
-         query (str): Query string for similarity search.
-         text_chunk_size (int): Size of each text chunk.
-         text_chunk_overlap (int): Overlap between text chunks.
-         number_of_similarity_results (int): Number of most similar results to return.
-
-     Returns:
-         list: List of most similar documents.
-     """
      documents = []
-
      for content in contents:
          if content['html']:
              try:
-                 # Split text into chunks
-                 text_splitter = RecursiveCharacterTextSplitter(
-                     chunk_size=text_chunk_size,
-                     chunk_overlap=text_chunk_overlap
-                 )
-                 texts = text_splitter.split_text(content['html'])
-
-                 # Create metadata for each text chunk
-                 metadatas = [{'title': content['title'], 'link': content['link']} for _ in range(len(texts))]
-
-                 # Create vector store
-                 embeddings = OpenAIEmbeddings()
-                 docsearch = FAISS.from_texts(texts, embedding=embeddings, metadatas=metadatas)
-
-                 # Perform similarity search
-                 docs = docsearch.similarity_search(query, k=number_of_similarity_results)
-                 doc_dicts = [{'page_content': doc.page_content, 'metadata': doc.metadata} for doc in docs]
-                 documents.extend(doc_dicts)
-
              except Exception as e:
                  console.log(f"[gray]Error processing content for {content['link']}: {e}")
 
-
-     return documents
 
 
- def answer_query_with_sources(query, relevant_docs):
-     from messages import get_query_with_sources_messages
-     messages = get_query_with_sources_messages(query, relevant_docs)
-     response = chat.invoke(messages, config={"callbacks": callbacks})
-     return response
 
  console = Console()
  dotenv.load_dotenv()
 
  callbacks = []
- if(os.getenv("LANGCHAIN_API_KEY")):
      callbacks.append(
          LangChainTracer(
              project_name="search agent",
@@ -230,44 +220,44 @@ if(os.getenv("LANGCHAIN_API_KEY")):
          )
      )
 
- if __name__ == '__main__':
      arguments = docopt(__doc__, version='Search Agent 0.1')
 
      provider = arguments["--provider"]
      model = arguments["--model"]
      temperature = float(arguments["--temperature"])
-     domain=arguments["--domain"]
      max_pages=arguments["--max_pages"]
      output=arguments["--output"]
      query = arguments["SEARCH_QUERY"]
-
      chat = get_chat_llm(provider, model, temperature)
-
      with console.status(f"[bold green]Optimizing query for search: {query}"):
-         optimize_search_query = optimize_search_query(query)
-         console.log(f"Optimized search query: [bold blue]{optimize_search_query}")
-
-     with console.status(f"[bold green]Searching sources using the optimized query: {optimize_search_query}"):
          sources = get_sources(optimize_search_query, max_pages=max_pages, domain=domain)
          console.log(f"Found {len(sources)} sources {'on ' + domain if domain else ''}")
 
-     with console.status(f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical"):
          contents = get_links_contents(sources)
          console.log(f"Managed to extract content from {len(contents)} sources")
 
-     with console.status(
-         f"[bold green]Processing {len(contents)} contents and finding relevant extracts",
-         spinner="dots8Bit"
-     ):
-         relevant_docs = process_and_vectorize_content(contents, query)
-         console.log(f"Filtered {len(relevant_docs)} relevant content extracts")
 
-     with console.status(f"[bold green]Querying LLM with {len(relevant_docs)} relevant extracts", spinner='dots8Bit'):
-         respomse = answer_query_with_sources(query, relevant_docs)
 
      console.rule(f"[bold green]Response from {provider}")
      if output == "text":
-         console.print(respomse.content)
      else:
-         console.print(Markdown(respomse.content))
      console.rule("[bold green]")
@@ -32,28 +32,31 @@ from bs4 import BeautifulSoup
  from docopt import docopt
  import dotenv
 
+ from langchain_core.documents.base import Document
  from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.retrievers.multi_query import MultiQueryRetriever
  from langchain.callbacks import LangChainTracer
  from langchain_groq import ChatGroq
  from langchain_openai import ChatOpenAI
  from langchain_openai import OpenAIEmbeddings
  from langchain_community.chat_models.bedrock import BedrockChat
+ from langchain_community.chat_models.ollama import ChatOllama
+ from langchain_community.vectorstores.faiss import FAISS
+
  from langsmith import Client
 
  import requests
 
  from rich.console import Console
  from rich.markdown import Markdown
 
+ from messages import get_rag_prompt_template, get_optimized_search_messages
 
+
+ def get_chat_llm(provider, model=None, temperature=0.0):
      match provider:
          case 'bedrock':
+             if model is None:
                  model = "anthropic.claude-3-sonnet-20240229-v1:0"
              chat_llm = BedrockChat(
                  credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME'),
@@ -61,29 +64,28 @@ def get_chat_llm(provider, model, temperature=0.0):
                  model_kwargs={"temperature": temperature },
              )
          case 'openai':
+             if model is None:
                  model = "gpt-3.5-turbo"
              chat_llm = ChatOpenAI(model_name=model, temperature=temperature)
          case 'groq':
+             if model is None:
                  model = 'mixtral-8x7b-32768'
              chat_llm = ChatGroq(model_name=model, temperature=temperature)
          case 'ollama':
+             if model is None:
+                 model = 'llama2'
              chat_llm = ChatOllama(model=model, temperature=temperature)
          case _:
              raise ValueError(f"Unknown LLM provider {provider}")
+
+     console.log(f"Using {model} on {provider} with temperature {temperature}")
      return chat_llm
 
+ def optimize_search_query(chat_llm, query):
      messages = get_optimized_search_messages(query)
+     response = chat_llm.invoke(messages, config={"callbacks": callbacks})
      optimized_search_query = response.content
+     return optimized_search_query.strip('"').split("**", 1)[0]
 
 
  def get_sources(query, max_pages=10, domain=None):
@@ -99,10 +101,10 @@ def get_sources(query, max_pages=10, domain=None):
      }
 
      try:
+         response = requests.get(url, headers=headers, timeout=30)
 
          if response.status_code != 200:
+             return []
 
          json_response = response.json()
 
@@ -140,8 +142,7 @@ def extract_main_content(html):
              element.extract()
          main_content = ' '.join(soup.body.get_text().split())
          return main_content
+     except Exception:
          return None
 
  def process_source(source):
@@ -159,68 +160,57 @@ def get_links_contents(sources):
      # Filter out None results
      return [result for result in results if result is not None]
 
+ def vectorize(contents, text_chunk_size=1000, text_chunk_overlap=200,):
      documents = []
      for content in contents:
          if content['html']:
              try:
+                 page_content = content['html']
+                 metadata = {'title': content['title'], 'source': content['link']}
+                 doc = Document(page_content=page_content, metadata=metadata)
+                 documents.append(doc)
              except Exception as e:
                  console.log(f"[gray]Error processing content for {content['link']}: {e}")
 
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=text_chunk_size,
+         chunk_overlap=text_chunk_overlap
+     )
+     docs = text_splitter.split_documents(documents)
+     embeddings = OpenAIEmbeddings()
+     store = FAISS.from_documents(docs, embeddings)
+     return store
+
+ def format_docs(docs):
+     formatted_docs = []
+     for d in docs:
+         content = d.page_content
+         title = d.metadata['title']
+         source = d.metadata['source']
+         doc = {"content": content, "title": title, "link": source}
+         formatted_docs.append(doc)
+     docs_as_json = json.dumps(formatted_docs, indent=2, ensure_ascii=False)
+     return docs_as_json
+
+
+ def query_rag(chat_llm, question, search_query, vectorstore):
+     retriever_from_llm = MultiQueryRetriever.from_llm(
+         retriever=vectorstore.as_retriever(), llm=chat_llm,
+     )
+     unique_docs = retriever_from_llm.get_relevant_documents(query=search_query, config={"callbacks": callbacks})
+     context = format_docs(unique_docs)
+     prompt = get_rag_prompt_template().format(query=question, context=context)
+     response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
+     return response.content
+
 
 
  console = Console()
  dotenv.load_dotenv()
 
  callbacks = []
+ if os.getenv("LANGCHAIN_API_KEY"):
      callbacks.append(
          LangChainTracer(
              project_name="search agent",
@@ -230,44 +220,44 @@ if(os.getenv("LANGCHAIN_API_KEY")):
          )
      )
 
+ if __name__ == '__main__':
      arguments = docopt(__doc__, version='Search Agent 0.1')
 
      provider = arguments["--provider"]
      model = arguments["--model"]
      temperature = float(arguments["--temperature"])
+     domain=arguments["--domain"]
      max_pages=arguments["--max_pages"]
      output=arguments["--output"]
      query = arguments["SEARCH_QUERY"]
+
      chat = get_chat_llm(provider, model, temperature)
+
      with console.status(f"[bold green]Optimizing query for search: {query}"):
+         optimize_search_query = optimize_search_query(chat, query)
+         console.log(f"Optimized search query: [bold blue]{optimize_search_query}")
+
+     with console.status(
+         f"[bold green]Searching sources using the optimized query: {optimize_search_query}"
+     ):
          sources = get_sources(optimize_search_query, max_pages=max_pages, domain=domain)
          console.log(f"Found {len(sources)} sources {'on ' + domain if domain else ''}")
 
+     with console.status(
+         f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical"
+     ):
          contents = get_links_contents(sources)
          console.log(f"Managed to extract content from {len(contents)} sources")
 
+     with console.status(f"[bold green]Embedding {len(sources)} sources", spinner="growVertical"):
+         vector_store = vectorize(contents)
 
+     with console.status("[bold green]Querying LLM relevant context", spinner='dots8Bit'):
+         response = query_rag(chat, query, optimize_search_query, vector_store)
 
      console.rule(f"[bold green]Response from {provider}")
      if output == "text":
+         console.print(response)
      else:
+         console.print(Markdown(response))
      console.rule("[bold green]")
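Taken together, the new flow is: optimize the question into a short search string, fetch and embed the result pages into FAISS, then let MultiQueryRetriever generate alternative phrasings of the search query and merge the retrieved chunks before the final RAG answer. Below is a condensed sketch of that flow with the search and fetch steps stubbed out (illustrative only; the stub document, its text, and the example question are invented for the sketch and are not part of the commit):

# Condensed sketch of the new multi-query RAG flow (stubbed inputs, illustrative only).
# Requires OPENAI_API_KEY in the environment for the chat model and embeddings.
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores.faiss import FAISS
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_core.documents.base import Document
from messages import get_rag_prompt_template

chat = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.0)

# Stub: in search_agent.py these documents come from get_sources() + get_links_contents() + vectorize().
docs = [Document(page_content="Mixtral 8x7B is a sparse mixture-of-experts language model...",
                 metadata={"title": "Example source", "source": "https://example.com/mixtral"})]
store = FAISS.from_documents(docs, OpenAIEmbeddings())

# MultiQueryRetriever asks the LLM for alternative phrasings of the search query
# and returns the unique documents retrieved across all of them.
retriever = MultiQueryRetriever.from_llm(retriever=store.as_retriever(), llm=chat)
unique_docs = retriever.get_relevant_documents("mixtral 8x7b details")

# search_agent.py serializes the docs to JSON via format_docs(); plain text is enough for a sketch.
context = "\n\n".join(d.page_content for d in unique_docs)
prompt = get_rag_prompt_template().format(query="What is Mixtral 8x7B?", context=context)
print(chat.invoke(prompt).content)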