# Command-line entry point for the search agent; the CLI is defined by the docopt docstring below.
"""search_agent.py

Usage:
    search_agent.py 
        [--domain=domain]
        [--provider=provider]
        [--model=model]
        [--temperature=temp]
        [--copywrite]
        [--max_pages=num]
        [--max_extracts=num]
        [--output=text]
        SEARCH_QUERY
    search_agent.py --version

Options:
    -h --help                           Show this screen.
    --version                           Show version.
    -c --copywrite                      First produce a draft, review it and rewrite for a final text
    -d domain --domain=domain           Limit search to a specific domain
    -t temp --temperature=temp          Set the temperature of the LLM [default: 0.0]
    -p provider --provider=provider     Use a specific LLM (choices: bedrock,openai,groq,ollama,cohere,fireworks) [default: openai]
    -m model --model=model              Use a specific model
    -n num --max_pages=num              Max number of pages to retrieve [default: 10]
    -e num --max_extracts=num           Max number of page extract to consider [default: 5]
    -o text --output=text               Output format (choices: text, markdown) [default: markdown]

"""

import os

from docopt import docopt
#from schema import Schema, Use, SchemaError
import dotenv

from langchain.callbacks import LangChainTracer

from langsmith import Client

from rich.console import Console
from rich.markdown import Markdown

import web_rag as wr
import web_crawler as wc
import copywriter as cw

# Shared Rich console used by the script for status spinners, log lines,
# rules and the rendered output.
console = Console()
# Load variables from a .env file (if one exists) into the environment so the
# os.getenv("LANGCHAIN_API_KEY") check further down can see them.
dotenv.load_dotenv()

def get_selenium_driver():
    """Create and return a headless Chrome WebDriver.

    Selenium is imported inside the function, so it is only loaded when a
    driver is actually requested.

    Returns:
        A ``selenium.webdriver.Chrome`` instance started with the flags listed
        below. The driver is returned still running; this function never quits
        it, so shutting it down is left to whoever receives it.
    """
    from selenium import webdriver
    from selenium.webdriver.chrome.options import Options

    chrome_flags = (
        "headless",
        "--disable-extensions",
        "--disable-gpu",
        "--no-sandbox",
        "--disable-dev-shm-usage",
        # NOTE(review): a fixed debugging port presumably prevents running
        # several drivers at once on the same host -- confirm with callers.
        "--remote-debugging-port=9222",
        # Skip image loading; only the page content is of interest.
        "--blink-settings=imagesEnabled=false",
        "--window-size=1920,1080",
    )

    chrome_options = Options()
    for flag in chrome_flags:
        chrome_options.add_argument(flag)

    return webdriver.Chrome(options=chrome_options)

# Callbacks handed to every LLM call made by the script. A LangChain tracer
# backed by a LangSmith client is registered only when LANGCHAIN_API_KEY is
# set; otherwise the list stays empty.
callbacks = [LangChainTracer(client=Client())] if os.getenv("LANGCHAIN_API_KEY") else []

def _print_section(title, body, output_format):
    """Print *body* framed by two console rules, the first one showing *title*.

    Args:
        title: Rich-markup string displayed on the opening rule.
        body: The text to display.
        output_format: When ``"text"``, *body* is passed straight to
            ``console.print``; for any other value it is wrapped in
            ``Markdown`` first.
    """
    console.rule(title)
    if output_format == "text":
        console.print(body)
    else:
        console.print(Markdown(body))
    console.rule("[bold green]")


if __name__ == '__main__':
    arguments = docopt(__doc__, version='Search Agent 0.1')

    # docopt hands back option values as strings, so the numeric options are
    # converted here. --max_pages is converted like --max_extracts so that the
    # crawler receives an integer rather than the raw command-line string.
    copywrite_mode = arguments["--copywrite"]
    provider = arguments["--provider"]
    model = arguments["--model"]
    temperature = float(arguments["--temperature"])
    domain = arguments["--domain"]
    max_pages = int(arguments["--max_pages"])
    max_extract = int(arguments["--max_extracts"])
    output = arguments["--output"]
    query = arguments["SEARCH_QUERY"]

    chat, embedding_model = wr.get_models(provider, model, temperature)

    # Step 1: let the LLM rewrite the user's query into a search-engine query.
    with console.status(f"[bold green]Optimizing query for search: {query}"):
        optimize_search_query = wr.optimize_search_query(chat, query, callbacks=callbacks)
        # Fall back to the original query when the rewrite is shorter than 3 characters.
        if len(optimize_search_query) < 3:
            optimize_search_query = query
    console.log(f"Optimized search query: [bold blue]{optimize_search_query}")

    # Step 2: look up candidate sources for the optimized query.
    with console.status(
            f"[bold green]Searching sources using the optimized query: {optimize_search_query}"
        ):
        sources = wc.get_sources(optimize_search_query, max_pages=max_pages, domain=domain)
    console.log(f"Found {len(sources)} sources {'on ' + domain if domain else ''}")

    # Step 3: fetch the content of the sources; the driver factory is passed
    # uncalled, so the crawler decides if and when a browser is started.
    with console.status(
        f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical"
    ):
        contents = wc.get_links_contents(sources, get_selenium_driver)
    console.log(f"Managed to extract content from {len(contents)} sources")

    # Step 4: embed the fetched content into a vector store.
    with console.status(f"[bold green]Embedding {len(contents)} sources for content", spinner="growVertical"):
        vector_store = wc.vectorize(contents, embedding_model)

    # Step 5: answer the original query from the top `max_extract` extracts.
    with console.status("[bold green]Writing content", spinner='dots8Bit'):
        draft = wr.query_rag(chat, query, optimize_search_query, vector_store, top_k = max_extract, callbacks=callbacks)

    _print_section(f"[bold green]Response from {provider}", draft, output)

    # Optional copywriting pass: review the draft, then rewrite it using the comments.
    if copywrite_mode:
        with console.status("[bold green]Getting comments from the reviewer", spinner="dots8Bit"):
            comments = cw.generate_comments(chat, query, draft, callbacks=callbacks)

        _print_section("[bold green]Response from reviewer", comments, output)

        with console.status("[bold green]Writing the final text", spinner="dots8Bit"):
            final_text = cw.generate_final_text(chat, query, draft, comments, callbacks=callbacks)

        _print_section("[bold green]Final text", final_text, output)