# File size: 5,468 Bytes
# 2091d19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import asyncio
import os
from datetime import date

from consts import PROJECT_ROOT_DIR

# from dotenv import find_dotenv, load_dotenv
from generate_arxiv_responses import ArxivResponseGenerator
from llama_index.core.agent.workflow import AgentWorkflow, ReActAgent
from llama_index.core.tools import FunctionTool
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec

from src.agent_hackathon.logger import get_logger

# _ = load_dotenv(dotenv_path=find_dotenv(raise_error_if_not_found=False), override=True)

# Module-level logger shared by everything in this file. It is created through
# the project's get_logger helper with the name "multiagent" and a log
# directory of <PROJECT_ROOT_DIR>/logs.
logger = get_logger(log_name="multiagent", log_dir=PROJECT_ROOT_DIR / "logs")


class MultiAgentWorkflow:
    """Multi-agent workflow for retrieving research papers and related events.

    The workflow combines two sources for a single user query:

    * arXiv papers, fetched synchronously through ``ArxivResponseGenerator``
      (see ``_arxiv_rag``), and
    * web search results, produced by a ``ReActAgent`` that is given the
      ``duckduckgo_full_search`` tool and run inside an ``AgentWorkflow``.

    ``run`` is the public entry point; it returns both parts joined into one
    string.
    """

    def __init__(self) -> None:
        """Initialize the workflow with LLM, tools, and generator."""
        logger.info("Initializing MultiAgentWorkflow.")
        self.llm = HuggingFaceInferenceAPI(
            model="meta-llama/Llama-3.3-70B-Instruct",
            provider="auto",
            # provider="nebius",
            temperature=0.1,
            top_p=0.95,
            max_tokens=8192,
            # api_key=os.getenv(key="NEBIUS_API_KEY"),
            # base_url="https://api.studio.nebius.com/v1/",
        )
        # Generator used by ``_arxiv_rag``; it is pointed at the vector store
        # file under <PROJECT_ROOT_DIR>/db/arxiv_docs.db.
        self._generator = ArxivResponseGenerator(
            vector_store_path=PROJECT_ROOT_DIR / "db/arxiv_docs.db"
        )
        # self._arxiv_rag_tool = FunctionTool.from_defaults(
        #     fn=self._arxiv_rag,
        #     name="arxiv_rag",
        #     description="Retrieves arxiv research papers.",
        #     return_direct=True,
        # )
        # Keep only the tool whose metadata name is "duckduckgo_full_search".
        # The result is a list (empty if the tool spec exposes no such tool),
        # which is the shape passed as ``tools=`` to the agent below.
        self._duckduckgo_search_tool = [
            tool
            for tool in DuckDuckGoSearchToolSpec().to_tool_list()
            if tool.metadata.name == "duckduckgo_full_search"
        ]
        # self._arxiv_agent = ReActAgent(
        #     name="arxiv_agent",
        #     description="Retrieves information about arxiv research papers",
        #     system_prompt="You are arxiv research paper agent, who retrieves information "
        #     "about arxiv research papers.",
        #     tools=[self._arxiv_rag_tool],
        #     llm=self.llm,
        # )
        self._websearch_agent = ReActAgent(
            name="web_search",
            description="Searches the web",
            system_prompt="You are search engine who searches the web using duckduckgo tool",
            tools=self._duckduckgo_search_tool,
            llm=self.llm,
        )

        # ``root_agent`` must match the ``name`` given to the agent above.
        self._workflow = AgentWorkflow(
            agents=[self._websearch_agent],
            root_agent="web_search",
            timeout=180,
        )
        # AgentWorkflow.from_tools_or_functions(
        #     tools_or_functions=self._duckduckgo_search_tool,
        #     llm=self.llm,
        #     system_prompt="You are an expert that  "
        #     "searches for any corresponding events related to the "
        #     "user query "
        #     "using the duckduckgo_search_tool and returns the final results." \
        #     "Don't return the steps but execute the necessary tools that you have " \
        #     "access to and return the results.",
        #     timeout=180,
        # )

        logger.info("MultiAgentWorkflow initialized.")

    def _arxiv_rag(self, query: str) -> str:
        """Retrieve research papers from arXiv based on the query.

        Thin wrapper that forwards the query to
        ``ArxivResponseGenerator.retrieve_arxiv_papers``.

        Args:
            query (str): The search query.

        Returns:
            str: Retrieved research papers as a string.
        """
        return self._generator.retrieve_arxiv_papers(query=query)

    def _clean_response(self, result: str) -> str:
        """Remove the leading ``<think>...</think>`` section from a response.

        Everything up to and including the first ``</think>`` closing tag is
        dropped. If the tag does not occur, the input is returned unchanged.

        Args:
            result (str): The result, possibly with <think></think> content.

        Returns:
            str: The result without the <think></think> content.
        """
        # ``str.find`` returns -1 (truthy) when the tag is missing and 0
        # (falsy) when it is at the very start, so it cannot be used as a
        # presence test. ``partition`` yields an empty separator exactly when
        # the tag is absent.
        _, closing_tag, remainder = result.partition("</think>")
        return remainder if closing_tag else result

    async def run(self, user_query: str) -> str:
        """Run the multi-agent workflow for a given user query.

        The arXiv papers are retrieved first, then the agent workflow is
        awaited with a prompt asking for events related to the query in the
        current year. The two parts are joined with a blank line between them.

        Args:
            user_query (str): The user's search query.

        Returns:
            str: The retrieved papers followed by the text of the first block
                of the web search agent's response.

        Raises:
            Exception: Any error raised during retrieval or the workflow run
                is logged and re-raised unchanged.
        """
        logger.info("Running multi-agent workflow.")
        try:
            research_papers = self._arxiv_rag(query=user_query)
            user_msg = (
                f"search with the web search agent to find any relevant events related to: {user_query}.\n"
                f" The web search results relevant to the current year: {date.today().year}. \n"
            )
            web_search_results = await self._workflow.run(user_msg=user_msg)
            # Only the first content block of the agent's reply is used.
            final_res = (
                research_papers + "\n\n" + web_search_results.response.blocks[0].text
            )
            logger.info("Workflow run completed successfully.")
            return final_res
        except Exception as err:
            logger.error(f"Workflow run failed: {err}")
            raise


if __name__ == "__main__":
    # Manual smoke run: build the workflow, execute it once for a fixed
    # sample query, and print whatever it returns. Any failure during the run
    # is logged here instead of propagating as a traceback.
    USER_QUERY = "i want to learn more about nlp"
    workflow = MultiAgentWorkflow()
    logger.info("Starting workflow for user query.")
    try:
        pending_run = workflow.run(user_query=USER_QUERY)
        result = asyncio.run(pending_run)
        logger.info("Workflow finished. Output below:")
        print(result)
    except Exception as err:
        logger.error(f"Error during workflow execution: {err}")