search_agent / search_agent.py
CyranoB's picture
Cohere support. Added Selenium to fetch some pages.
d21cce9
raw
history blame
11.8 kB
"""search_agent.py
Usage:
search_agent.py
[--domain=domain]
[--provider=provider]
[--model=model]
[--temperature=temp]
[--max_pages=num]
[--output=text]
SEARCH_QUERY
search_agent.py --version
Options:
-h --help Show this screen.
--version Show version.
-d domain --domain=domain Limit search to a specific domain
-t temp --temperature=temp Set the temperature of the LLM [default: 0.0]
-p provider --provider=provider Use a specific LLM (choices: bedrock,openai,groq,ollama,cohere) [default: openai]
-m model --model=model Use a specific model
-n num --max_pages=num Max number of pages to retrieve [default: 10]
-o text --output=text Output format (choices: text, markdown) [default: markdown]
"""
import json
import os
import io
from concurrent.futures import ThreadPoolExecutor
from urllib.parse import quote
from docopt import docopt
import dotenv
import pdfplumber
from trafilatura import extract
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from langchain_core.documents.base import Document
from langchain_experimental.text_splitter import SemanticChunker
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.callbacks import LangChainTracer
from langchain_cohere.chat_models import ChatCohere
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langchain_community.chat_models.bedrock import BedrockChat
from langchain_community.chat_models.ollama import ChatOllama
from langchain_community.vectorstores.faiss import FAISS
from langsmith import Client
import requests
from rich.console import Console
from rich.markdown import Markdown
from messages import get_rag_prompt_template, get_optimized_search_messages
def get_chat_llm(provider, model=None, temperature=0.0):
match provider:
case 'bedrock':
if model is None:
model = "anthropic.claude-3-sonnet-20240229-v1:0"
chat_llm = BedrockChat(
credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME'),
model_id=model,
model_kwargs={"temperature": temperature },
)
case 'openai':
if model is None:
model = "gpt-3.5-turbo"
chat_llm = ChatOpenAI(model_name=model, temperature=temperature)
case 'groq':
if model is None:
model = 'mixtral-8x7b-32768'
chat_llm = ChatGroq(model_name=model, temperature=temperature)
case 'ollama':
if model is None:
model = 'llama2'
chat_llm = ChatOllama(model=model, temperature=temperature)
case 'cohere':
if model is None:
model = 'command-r-plus'
chat_llm = ChatCohere(model=model, temperature=temperature)
case _:
raise ValueError(f"Unknown LLM provider {provider}")
console.log(f"Using {model} on {provider} with temperature {temperature}")
return chat_llm
def optimize_search_query(chat_llm, query):
messages = get_optimized_search_messages(query)
response = chat_llm.invoke(messages, config={"callbacks": callbacks})
optimized_search_query = response.content
return optimized_search_query.strip('"').split("**", 1)[0]
def get_sources(query, max_pages=10, domain=None):
search_query = query
if domain:
search_query += f" site:{domain}"
url = f"https://api.search.brave.com/res/v1/web/search?q={quote(search_query)}&count={max_pages}"
headers = {
'Accept': 'application/json',
'Accept-Encoding': 'gzip',
'X-Subscription-Token': os.getenv("BRAVE_SEARCH_API_KEY")
}
try:
response = requests.get(url, headers=headers, timeout=30)
if response.status_code != 200:
return []
json_response = response.json()
if 'web' not in json_response or 'results' not in json_response['web']:
raise Exception('Invalid API response format')
final_results = [{
'title': result['title'],
'link': result['url'],
'snippet': result['description'],
'favicon': result.get('profile', {}).get('img', '')
} for result in json_response['web']['results']]
return final_results
except Exception as error:
console.log('Error fetching search results:', error)
raise
def fetch_with_selenium(url, timeout=8):
chrome_options = Options()
chrome_options.add_argument("headless")
chrome_options.add_argument("--disable-extensions")
chrome_options.add_argument("--disable-gpu")
chrome_options.add_argument("--no-sandbox")
chrome_options.add_argument("--disable-dev-shm-usage")
chrome_options.add_argument("--remote-debugging-port=9222")
chrome_options.add_argument('--blink-settings=imagesEnabled=false')
chrome_options.add_argument("--window-size=1920,1080")
driver = webdriver.Chrome(options=chrome_options)
driver.get(url)
driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
html = driver.page_source
driver.quit()
return html
def fetch_with_timeout(url, timeout=8):
try:
response = requests.get(url, timeout=timeout)
response.raise_for_status()
return response
except requests.RequestException as error:
return None
def process_source(source):
url = source['link']
#console.log(f"Processing {url}")
response = fetch_with_timeout(url, 8)
if response:
content_type = response.headers.get('Content-Type')
if content_type:
if content_type.startswith('application/pdf'):
# The response is a PDF file
pdf_content = response.content
# Create a file-like object from the bytes
pdf_file = io.BytesIO(pdf_content)
# Extract text from PDF using pdfplumber
with pdfplumber.open(pdf_file) as pdf:
text = ""
for page in pdf.pages:
text += page.extract_text()
return {**source, 'page_content': text}
elif content_type.startswith('text/html'):
# The response is an HTML file
html = response.text
main_content = extract(html, output_format='txt', include_links=True)
return {**source, 'page_content': main_content}
else:
console.log(f"Skipping {url}! Unsupported content type: {content_type}")
return {**source, 'page_content': source['snippet']}
else:
console.log(f"Skipping {url}! No content type")
return {**source, 'page_content': source['snippet']}
return {**source, 'page_content': None}
def get_links_contents(sources):
with ThreadPoolExecutor() as executor:
results = list(executor.map(process_source, sources))
for result in results:
if result['page_content'] is None:
url = result['link']
console.log(f"Fetching with selenium {url}")
html = fetch_with_selenium(url, 8)
main_content = extract(html, output_format='txt', include_links=True)
if main_content:
result['page_content'] = main_content
# Filter out None results
return [result for result in results if result is not None]
def vectorize(contents, text_chunk_size=400,text_chunk_overlap=40):
documents = []
for content in contents:
try:
metadata = {'title': content['title'], 'source': content['link']}
doc = Document(page_content=content['page_content'], metadata=metadata)
documents.append(doc)
except Exception as e:
console.log(f"[gray]Error processing content for {content['link']}: {e}")
semantic_chunker = SemanticChunker(OpenAIEmbeddings(model="text-embedding-3-large"), breakpoint_threshold_type="percentile")
docs = semantic_chunker.split_documents(documents)
console.log(f"Vectorizing {len(docs)} document chunks")
embeddings = OpenAIEmbeddings()
store = FAISS.from_documents(docs, embeddings)
return store
def format_docs(docs):
formatted_docs = []
for d in docs:
content = d.page_content
title = d.metadata['title']
source = d.metadata['source']
doc = {"content": content, "title": title, "link": source}
formatted_docs.append(doc)
docs_as_json = json.dumps(formatted_docs, indent=2, ensure_ascii=False)
return docs_as_json
def multi_query_rag(chat_llm, question, search_query, vectorstore):
retriever_from_llm = MultiQueryRetriever.from_llm(
retriever=vectorstore.as_retriever(), llm=chat_llm, include_original=True,
)
unique_docs = retriever_from_llm.get_relevant_documents(
query=search_query, callbacks=callbacks, verbose=True
)
context = format_docs(unique_docs)
prompt = get_rag_prompt_template().format(query=question, context=context)
response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
return response.content
def query_rag(chat_llm, question, search_query, vectorstore):
unique_docs = vectorstore.similarity_search(search_query, k=15, callbacks=callbacks, verbose=True)
context = format_docs(unique_docs)
prompt = get_rag_prompt_template().format(query=question, context=context)
response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
return response.content
console = Console()
dotenv.load_dotenv()
callbacks = []
if os.getenv("LANGCHAIN_API_KEY"):
callbacks.append(
LangChainTracer(
project_name="search agent",
client=Client(
api_url="https://api.smith.langchain.com",
)
)
)
if __name__ == '__main__':
arguments = docopt(__doc__, version='Search Agent 0.1')
provider = arguments["--provider"]
model = arguments["--model"]
temperature = float(arguments["--temperature"])
domain=arguments["--domain"]
max_pages=arguments["--max_pages"]
output=arguments["--output"]
query = arguments["SEARCH_QUERY"]
chat = get_chat_llm(provider, model, temperature)
with console.status(f"[bold green]Optimizing query for search: {query}"):
optimize_search_query = optimize_search_query(chat, query)
console.log(f"Optimized search query: [bold blue]{optimize_search_query}")
with console.status(
f"[bold green]Searching sources using the optimized query: {optimize_search_query}"
):
sources = get_sources(optimize_search_query, max_pages=max_pages, domain=domain)
console.log(f"Found {len(sources)} sources {'on ' + domain if domain else ''}")
with console.status(
f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical"
):
contents = get_links_contents(sources)
console.log(f"Managed to extract content from {len(contents)} sources")
with console.status(f"[bold green]Embeddubg {len(contents)} sources for content", spinner="growVertical"):
vector_store = vectorize(contents)
with console.status("[bold green]Querying LLM relevant context", spinner='dots8Bit'):
respomse = query_rag(chat, query, optimize_search_query, vector_store)
console.rule(f"[bold green]Response from {provider}")
if output == "text":
console.print(respomse)
else:
console.print(Markdown(respomse))
console.rule("[bold green]")