# File size: 5,468 Bytes
# 2091d19
import asyncio
import os
from datetime import date
from consts import PROJECT_ROOT_DIR
# from dotenv import find_dotenv, load_dotenv
from generate_arxiv_responses import ArxivResponseGenerator
from llama_index.core.agent.workflow import AgentWorkflow, ReActAgent
from llama_index.core.tools import FunctionTool
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
from src.agent_hackathon.logger import get_logger
# _ = load_dotenv(dotenv_path=find_dotenv(raise_error_if_not_found=False), override=True)
# Module-level logger shared by the workflow class and the script entry point;
# it is registered under the name "multiagent" with PROJECT_ROOT_DIR / "logs"
# passed as its log directory.
logger = get_logger(log_name="multiagent", log_dir=PROJECT_ROOT_DIR / "logs")
class MultiAgentWorkflow:
    """Multi-agent workflow for retrieving research papers and related events.

    The workflow combines two sources into a single text answer:

    * arXiv papers retrieved through ``ArxivResponseGenerator``.
    * Web search results produced by a ``ReActAgent`` that is given only the
      ``duckduckgo_full_search`` tool.
    """

    def __init__(self) -> None:
        """Initialize the workflow with LLM, tools, and generator."""
        logger.info("Initializing MultiAgentWorkflow.")
        self.llm = HuggingFaceInferenceAPI(
            model="meta-llama/Llama-3.3-70B-Instruct",
            provider="auto",
            # provider="nebius",
            temperature=0.1,
            top_p=0.95,
            max_tokens=8192,
            # api_key=os.getenv(key="NEBIUS_API_KEY"),
            # base_url="https://api.studio.nebius.com/v1/",
        )
        self._generator = ArxivResponseGenerator(
            vector_store_path=PROJECT_ROOT_DIR / "db/arxiv_docs.db"
        )
        # self._arxiv_rag_tool = FunctionTool.from_defaults(
        #     fn=self._arxiv_rag,
        #     name="arxiv_rag",
        #     description="Retrieves arxiv research papers.",
        #     return_direct=True,
        # )

        # Keep only the full-search tool out of everything the DuckDuckGo
        # tool spec exposes; the result stays a list because the agent's
        # ``tools`` argument is given this value directly.
        self._duckduckgo_search_tool = [
            tool
            for tool in DuckDuckGoSearchToolSpec().to_tool_list()
            if tool.metadata.name == "duckduckgo_full_search"
        ]
        # self._arxiv_agent = ReActAgent(
        #     name="arxiv_agent",
        #     description="Retrieves information about arxiv research papers",
        #     system_prompt="You are arxiv research paper agent, who retrieves information "
        #     "about arxiv research papers.",
        #     tools=[self._arxiv_rag_tool],
        #     llm=self.llm,
        # )
        self._websearch_agent = ReActAgent(
            name="web_search",
            description="Searches the web",
            system_prompt="You are search engine who searches the web using duckduckgo tool",
            tools=self._duckduckgo_search_tool,
            llm=self.llm,
        )
        # ``root_agent`` must match the ``name`` given to the agent above.
        # NOTE(review): timeout is presumably in seconds - confirm against
        # the AgentWorkflow documentation.
        self._workflow = AgentWorkflow(
            agents=[self._websearch_agent],
            root_agent="web_search",
            timeout=180,
        )
        # AgentWorkflow.from_tools_or_functions(
        #     tools_or_functions=self._duckduckgo_search_tool,
        #     llm=self.llm,
        #     system_prompt="You are an expert that "
        #     "searches for any corresponding events related to the "
        #     "user query "
        #     "using the duckduckgo_search_tool and returns the final results." \
        #     "Don't return the steps but execute the necessary tools that you have " \
        #     "access to and return the results.",
        #     timeout=180,
        # )
        logger.info("MultiAgentWorkflow initialized.")

    def _arxiv_rag(self, query: str) -> str:
        """Retrieve research papers from arXiv based on the query.

        Args:
            query (str): The search query.

        Returns:
            str: Retrieved research papers as a string.
        """
        return self._generator.retrieve_arxiv_papers(query=query)

    def _clean_response(self, result: str) -> str:
        """Remove everything up to and including the first ``</think>`` tag.

        Args:
            result (str): The result, possibly containing ``<think></think>``
                content.

        Returns:
            str: The text following the first ``</think>`` tag, or ``result``
                unchanged when the tag is not present.
        """
        closing_tag = "</think>"
        tag_start = result.find(closing_tag)
        # ``str.find`` signals "not found" with -1, which is truthy, and a
        # match at the very start with 0, which is falsy - so the index has to
        # be compared against -1 explicitly instead of being used as a bool.
        if tag_start != -1:
            result = result[tag_start + len(closing_tag):]
        return result

    async def run(self, user_query: str) -> str:
        """Run the multi-agent workflow for a given user query.

        The arXiv retrieval runs first (synchronously), then the web search
        agent workflow is awaited, and both texts are joined with a blank
        line between them.

        Args:
            user_query (str): The user's search query.

        Returns:
            str: The retrieved research papers followed by the text of the
                first block of the web search agent's response.

        Raises:
            Exception: Any error raised during retrieval or the agent run is
                logged and re-raised unchanged.
        """
        logger.info("Running multi-agent workflow.")
        try:
            research_papers = self._arxiv_rag(query=user_query)
            user_msg = (
                f"search with the web search agent to find any relevant events related to: {user_query}.\n"
                f" The web search results relevant to the current year: {date.today().year}. \n"
            )
            web_search_results = await self._workflow.run(user_msg=user_msg)
            final_res = (
                research_papers + "\n\n" + web_search_results.response.blocks[0].text
            )
            logger.info("Workflow run completed successfully.")
            return final_res
        except Exception as err:
            logger.error(f"Workflow run failed: {err}")
            raise
if __name__ == "__main__":
    # Script entry point: build the workflow once, run it for a fixed sample
    # query, and print the combined answer. Failures during the run are
    # logged here rather than propagated.
    USER_QUERY = "i want to learn more about nlp"

    pipeline = MultiAgentWorkflow()
    logger.info("Starting workflow for user query.")
    try:
        pending_run = pipeline.run(user_query=USER_QUERY)
        output = asyncio.run(pending_run)
        logger.info("Workflow finished. Output below:")
        print(output)
    except Exception as exc:
        logger.error(f"Error during workflow execution: {exc}")
|