|
import asyncio |
|
import os |
|
from datetime import date |
|
|
|
from consts import PROJECT_ROOT_DIR |
|
|
|
|
|
from generate_arxiv_responses import ArxivResponseGenerator |
|
from llama_index.core.agent.workflow import AgentWorkflow, ReActAgent |
|
from llama_index.core.tools import FunctionTool |
|
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI |
|
from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec |
|
|
|
from src.agent_hackathon.logger import get_logger |
|
|
|
|
|
|
|
# Module-level logger shared by the workflow class and the script entry point;
# writes under <project-root>/logs with the "multiagent" log name.
logger = get_logger(log_name="multiagent", log_dir=PROJECT_ROOT_DIR / "logs")
|
|
|
|
|
class MultiAgentWorkflow:
    """Multi-agent workflow for retrieving research papers and related events."""

    def __init__(self) -> None:
        """Initialize the workflow with LLM, tools, and generator."""
        logger.info("Initializing MultiAgentWorkflow.")
        # Shared LLM used by every agent in the workflow.
        self.llm = HuggingFaceInferenceAPI(
            model="meta-llama/Llama-3.3-70B-Instruct",
            provider="auto",
            temperature=0.1,
            top_p=0.95,
            max_tokens=8192,
        )
        # RAG generator backed by the local arXiv vector store.
        self._generator = ArxivResponseGenerator(
            vector_store_path=PROJECT_ROOT_DIR / "db/arxiv_docs.db"
        )
        # Keep only the "full search" tool from the DuckDuckGo tool spec.
        self._duckduckgo_search_tool = [
            tool
            for tool in DuckDuckGoSearchToolSpec().to_tool_list()
            if tool.metadata.name == "duckduckgo_full_search"
        ]
        self._websearch_agent = ReActAgent(
            name="web_search",
            description="Searches the web",
            system_prompt="You are search engine who searches the web using duckduckgo tool",
            tools=self._duckduckgo_search_tool,
            llm=self.llm,
        )
        # Single-agent workflow rooted at the web-search agent.
        self._workflow = AgentWorkflow(
            agents=[self._websearch_agent],
            root_agent="web_search",
            timeout=180,
        )
        logger.info("MultiAgentWorkflow initialized.")

    def _arxiv_rag(self, query: str) -> str:
        """Retrieve research papers from arXiv based on the query.

        Args:
            query (str): The search query.

        Returns:
            str: Retrieved research papers as a string.
        """
        return self._generator.retrieve_arxiv_papers(query=query)

    def _clean_response(self, result: str) -> str:
        """Remove a leading chain-of-thought section ending in ``</think>``.

        Args:
            result (str): The result possibly containing <think></think> content.

        Returns:
            str: The result without the <think></think> content; unchanged
                when no closing tag is present.
        """
        # BUG FIX: str.find returns -1 (truthy) when the tag is absent, so the
        # original `if result.find("</think>"):` chopped the first 7 chars off
        # tag-free strings and skipped a tag located at index 0. Compare
        # against -1 explicitly instead of relying on truthiness.
        marker = "</think>"
        idx = result.find(marker)
        if idx != -1:
            result = result[idx + len(marker):]
        return result

    async def run(self, user_query: str) -> str:
        """Run the multi-agent workflow for a given user query.

        Args:
            user_query (str): The user's search query.

        Returns:
            str: The retrieved research papers followed by the web-search
                agent's response text, separated by a blank line.

        Raises:
            Exception: Re-raised after logging when any stage fails.
        """
        logger.info("Running multi-agent workflow.")
        try:
            research_papers = self._arxiv_rag(query=user_query)
            user_msg = (
                f"search with the web search agent to find any relevant events related to: {user_query}.\n"
                f" The web search results relevant to the current year: {date.today().year}. \n"
            )
            web_search_results = await self._workflow.run(user_msg=user_msg)
            # NOTE(review): _clean_response is never applied to the agent
            # output here — confirm whether think-tag stripping is intended.
            final_res = (
                research_papers + "\n\n" + web_search_results.response.blocks[0].text
            )
            logger.info("Workflow run completed successfully.")
            return final_res
        except Exception as err:
            # logger.exception records the traceback; lazy %-formatting avoids
            # building the message unless the record is emitted.
            logger.exception("Workflow run failed: %s", err)
            raise
|
|
|
|
|
if __name__ == "__main__":

    def _demo() -> None:
        """Run the workflow once for a hard-coded demo query and print the result."""
        demo_query = "i want to learn more about nlp"
        flow = MultiAgentWorkflow()
        logger.info("Starting workflow for user query.")
        try:
            output = asyncio.run(flow.run(user_query=demo_query))
            logger.info("Workflow finished. Output below:")
            print(output)
        except Exception as err:
            logger.error(f"Error during workflow execution: {err}")

    _demo()