# main_v2.py — author: mjschock, commit 2579190 (unverified), 4.18 kB.
# Commit message: Enhance main_v2.py by adding a ToolCallingAgent for improved
# web-search capabilities and by updating how the prompt templates are loaded.
# Unused agent initializations are commented out, and task handling is adjusted
# to streamline the agent. A new YAML-based prompt template is introduced for
# the CodeAgent to make task execution clearer.
import importlib
import logging
import os
import requests
import yaml
from dotenv import find_dotenv, load_dotenv
from litellm._logging import _disable_debugging
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from phoenix.otel import register
# from smolagents import CodeAgent, LiteLLMModel, LiteLLMRouterModel
from smolagents import CodeAgent, LiteLLMModel, ToolCallingAgent
from smolagents.default_tools import (
DuckDuckGoSearchTool,
VisitWebpageTool,
WikipediaSearchTool,
)
from smolagents.monitoring import LogLevel
from agents.data_agent.agent import create_data_agent
from agents.media_agent.agent import create_media_agent
from agents.web_agent.agent import create_web_agent
from utils import extract_final_answer
# Apply variables from a .env file (located with find_dotenv()) BEFORE anything
# below reads the environment. phoenix.otel.register() reads settings such as
# its collector endpoint / project name / API key from environment variables,
# so .env must already be applied when it runs. load_dotenv() does not override
# variables that are already set in the real environment.
load_dotenv(find_dotenv())

# Silence LiteLLM's debug logging (private helper from litellm._logging).
_disable_debugging()

# Configure OpenTelemetry with Phoenix and instrument smolagents so agent runs
# are exported as traces.
register()
SmolagentsInstrumentor().instrument()

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Model connection settings. Each is None when unset; None is passed straight
# through to LiteLLMModel unchanged.
API_BASE = os.getenv("API_BASE")
API_KEY = os.getenv("API_KEY")
MODEL_ID = os.getenv("MODEL_ID")

model = LiteLLMModel(
    api_base=API_BASE,
    api_key=API_KEY,
    model_id=MODEL_ID,
)
# Alternative configurations left disabled (kept for reference): specialised
# sub-agents and a ToolCallingAgent for web search that could be handed to the
# CodeAgent via ``managed_agents``.
# data_agent = create_data_agent(model)
# media_agent = create_media_agent(model)
# web_agent = create_web_agent(model)
# search_agent = ToolCallingAgent(
#     tools=[DuckDuckGoSearchTool(), VisitWebpageTool()],
#     model=model,
#     name="search_agent",
#     description="This is an agent that can do web search.",
# )

# Custom CodeAgent prompt templates. The path is relative to the current
# working directory, so the script must be started from the directory that
# contains ``prompts/``. Reading inside ``with`` closes the handle right away
# instead of leaking it until garbage collection; the encoding is pinned so the
# YAML decodes the same way regardless of the platform's default locale.
with open("prompts/code_agent.yaml", "r", encoding="utf-8") as prompt_file:
    prompt_templates = yaml.safe_load(prompt_file)

agent = CodeAgent(
    # Options currently disabled:
    # add_base_tools=True,
    # additional_authorized_imports=[
    #     "json",
    #     "pandas",
    #     "numpy",
    #     "re",
    #     # "requests"
    #     # "urllib.request",
    # ],
    # max_steps=10,
    # managed_agents=[web_agent, data_agent, media_agent],
    # managed_agents=[search_agent],
    model=model,
    prompt_templates=prompt_templates,
    tools=[
        DuckDuckGoSearchTool(max_results=3),
        VisitWebpageTool(max_output_length=1024),
        WikipediaSearchTool(),
    ],
    step_callbacks=None,
    verbosity_level=LogLevel.ERROR,
)

# Render the agent's structure via smolagents' visualize(); note this runs at
# import time.
agent.visualize()
def main(task: str, max_steps: int = 3):
    """Run the module-level CodeAgent on *task* and return its extracted answer.

    Args:
        task: The question/prompt passed verbatim to ``agent.run``.
        max_steps: Maximum number of agent steps for this run. Defaults to 3,
            the value that was previously hard-coded here.

    Returns:
        The value produced by ``extract_final_answer`` for the agent's raw
        result.
    """
    # Disabled GAIA-style answer-format instructions (kept for reference):
    # gaia_task = f"""Instructions:
    # 1. Your response must contain ONLY the answer to the question, nothing else
    # 2. Do not repeat the question or any part of it
    # 3. Do not include any explanations, reasoning, or context
    # 4. Do not include source attribution or references
    # 5. Do not use phrases like "The answer is" or "I found that"
    # 6. Do not include any formatting, bullet points, or line breaks
    # 7. If the answer is a number, return only the number
    # 8. If the answer requires multiple items, separate them with commas
    # 9. If the answer requires ordering, maintain the specified order
    # 10. Use the most direct and succinct form possible
    # {task}"""
    result = agent.run(
        task=task,  # or task=gaia_task when the instructions above are enabled
        additional_args=None,
        images=None,
        max_steps=max_steps,
        # Start each call from a fresh run rather than continuing the previous one.
        reset=True,
        stream=False,
    )
    logger.info("Result: %s", result)
    return extract_final_answer(result)
if __name__ == "__main__":
    DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    # submit_url = f"{api_url}/submit"  # not used yet: answers are only logged

    response = requests.get(questions_url, timeout=15)
    response.raise_for_status()
    questions_data = response.json()

    # Only the first question is processed (quick smoke run).
    for question_data in questions_data[:1]:
        question = question_data["question"]
        # Optional metadata: .get() so a missing key cannot abort the run.
        file_name = question_data.get("file_name")
        # task_id = question_data["task_id"]  # needed once answers are submitted

        logger.info("Question: %s", question)
        if file_name:
            # NOTE(review): the attachment is only logged; main() receives just
            # the question text, so the agent never sees this file.
            logger.info("File Name: %s", file_name)

        final_answer = main(question)
        logger.info("Final Answer: %s", final_answer)
        logger.info("--------------------------------")