File size: 3,953 Bytes
9c3709d
 
 
 
 
 
4c66227
9c3709d
 
7e9684b
9c3709d
 
 
 
 
 
 
 
d21cce9
4c66227
 
7e9684b
9c3709d
 
 
 
 
 
 
 
 
8d1e83e
9c3709d
 
 
 
 
d594a38
 
9c3709d
 
 
 
542890e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c3709d
8d1e83e
9c3709d
d594a38
9c3709d
 
8d1e83e
9c3709d
 
 
4c66227
9c3709d
8d1e83e
7e9684b
 
9c3709d
8d1e83e
d594a38
bda01ad
8d1e83e
9c3709d
d594a38
8d1e83e
 
 
 
 
d594a38
9c3709d
 
8d1e83e
 
 
542890e
9c3709d
 
615fe60
d594a38
9c3709d
8d1e83e
5cba10a
9c3709d
 
7e9684b
8d1e83e
7e9684b
8d1e83e
9c3709d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""search_agent.py

Usage:
    search_agent.py 
        [--domain=domain]
        [--provider=provider]
        [--model=model]
        [--temperature=temp]
        [--max_pages=num]
        [--output=text]
        SEARCH_QUERY
    search_agent.py --version

Options:
    -h --help                           Show this screen.
    --version                           Show version.
    -d domain --domain=domain           Limit search to a specific domain
    -t temp --temperature=temp          Set the temperature of the LLM [default: 0.0]
    -p provider --provider=provider     Use a specific LLM (choices: bedrock,openai,groq,ollama,cohere) [default: openai]
    -m model --model=model              Use a specific model
    -n num --max_pages=num              Max number of pages to retrieve [default: 10]
    -o text --output=text               Output format (choices: text, markdown) [default: markdown]

"""

import os

from docopt import docopt
import dotenv

from langchain.callbacks import LangChainTracer

from langsmith import Client

from rich.console import Console
from rich.markdown import Markdown

import web_rag as wr
import web_crawler as wc

console = Console()
dotenv.load_dotenv()

def get_selenium_driver():
    from selenium import webdriver
    from selenium.webdriver.chrome.options import Options
    from selenium.common.exceptions import TimeoutException

    chrome_options = Options()
    chrome_options.add_argument("headless")
    chrome_options.add_argument("--disable-extensions")
    chrome_options.add_argument("--disable-gpu")
    chrome_options.add_argument("--no-sandbox")
    chrome_options.add_argument("--disable-dev-shm-usage")
    chrome_options.add_argument("--remote-debugging-port=9222")
    chrome_options.add_argument('--blink-settings=imagesEnabled=false')
    chrome_options.add_argument("--window-size=1920,1080")

    driver = webdriver.Chrome(options=chrome_options)
    return driver



callbacks = []
if os.getenv("LANGCHAIN_API_KEY"):
    callbacks.append(
        LangChainTracer(client=Client())
    )

if __name__ == '__main__':
    arguments = docopt(__doc__, version='Search Agent 0.1')

    provider = arguments["--provider"]
    model = arguments["--model"]
    temperature = float(arguments["--temperature"])
    domain=arguments["--domain"]
    max_pages=arguments["--max_pages"]
    output=arguments["--output"]
    query = arguments["SEARCH_QUERY"]

    chat = wr.get_chat_llm(provider, model, temperature)
    console.log(f"Using {chat.model} on {provider} with temperature {temperature}")

    with console.status(f"[bold green]Optimizing query for search: {query}"):
        optimize_search_query = wr.optimize_search_query(chat, query, callbacks=callbacks)
    console.log(f"Optimized search query: [bold blue]{optimize_search_query}")

    with console.status(
            f"[bold green]Searching sources using the optimized query: {optimize_search_query}"
        ):
        sources = wc.get_sources(optimize_search_query, max_pages=max_pages, domain=domain)
    console.log(f"Found {len(sources)} sources {'on ' + domain if domain else ''}")

    with console.status(
        f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical"
    ):
        contents = wc.get_links_contents(sources, get_selenium_driver)
    console.log(f"Managed to extract content from {len(contents)} sources")

    with console.status(f"[bold green]Embeddubg {len(contents)} sources for content", spinner="growVertical"):
        vector_store = wc.vectorize(contents)

    with console.status("[bold green]Querying LLM relevant context", spinner='dots8Bit'):
        respomse = wr.query_rag(chat, query, optimize_search_query, vector_store, top_k = 5, callbacks=callbacks)

    console.rule(f"[bold green]Response from {provider}")
    if output == "text":
        console.print(respomse)
    else:
        console.print(Markdown(respomse))
    console.rule("[bold green]")