mjschock's picture
Update main_v2.py to format task instructions in GAIA-style, ensuring responses are concise and follow specific guidelines. Modify the task parameter to utilize the new formatted instructions for improved clarity and response accuracy.
2aa9dc2 unverified
raw
history blame
3.92 kB
import importlib
import importlib.resources
import logging
import os

import requests
import yaml
from dotenv import find_dotenv, load_dotenv
from litellm._logging import _disable_debugging
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from phoenix.otel import register
# from smolagents import CodeAgent, LiteLLMModel, LiteLLMRouterModel
from smolagents import CodeAgent, LiteLLMModel
from smolagents.default_tools import (
    DuckDuckGoSearchTool,
    VisitWebpageTool,
    WikipediaSearchTool,
)
from smolagents.monitoring import LogLevel

from agents.data_agent.agent import create_data_agent
from agents.media_agent.agent import create_media_agent
from agents.web_agent.agent import create_web_agent
from utils import extract_final_answer

# Load variables from the nearest .env file FIRST. Everything below depends on
# the environment: Phoenix's register() can be configured through environment
# variables (e.g. PHOENIX_COLLECTOR_ENDPOINT), and the model settings are read
# with os.getenv. Loading .env after register() would hide those values from
# the tracer.
load_dotenv(find_dotenv())

# Presumably silences LiteLLM's verbose debug output (private LiteLLM helper).
_disable_debugging()

# Configure OpenTelemetry with Phoenix and instrument smolagents so that agent
# runs are traced.
register()
SmolagentsInstrumentor().instrument()

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Model endpoint configuration; each value is None when the variable is unset.
API_BASE = os.getenv("API_BASE")
API_KEY = os.getenv("API_KEY")
MODEL_ID = os.getenv("MODEL_ID")

# Single LiteLLM-backed model shared by the top-level agent and the sub-agents.
model = LiteLLMModel(
    api_base=API_BASE,
    api_key=API_KEY,
    model_id=MODEL_ID,
)

# Specialist sub-agents, all built on the shared model and created in the order
# data -> media -> web. They are not currently wired into the top-level
# CodeAgent.
data_agent, media_agent, web_agent = (
    create_data_agent(model),
    create_media_agent(model),
    create_web_agent(model),
)

# Start from smolagents' stock CodeAgent prompt templates, read from the
# code_agent.yaml resource shipped inside the installed smolagents package.
# `importlib.resources` must be imported explicitly: `import importlib` alone
# does not load the submodule. The encoding is pinned so the YAML decodes the
# same way regardless of the platform's locale.
prompt_templates = yaml.safe_load(
    importlib.resources.files("smolagents.prompts")
    .joinpath("code_agent.yaml")
    .read_text(encoding="utf-8")
)

# Top-level agent. It writes and executes Python code, may import only the
# modules listed below (beyond smolagents' defaults), and has web search, page
# fetching and Wikipedia lookup as tools. No managed sub-agents are attached,
# and no step limit is set here; callers pass max_steps to agent.run().
agent = CodeAgent(
    model=model,
    prompt_templates=prompt_templates,
    tools=[
        DuckDuckGoSearchTool(max_results=3),   # keep search results short
        VisitWebpageTool(max_output_length=1024),  # truncate fetched pages
        WikipediaSearchTool(),
    ],
    additional_authorized_imports=["json", "pandas", "numpy", "re"],
    step_callbacks=None,
    verbosity_level=LogLevel.ERROR,
)

# Display the agent's configuration (tools, sub-agents) at startup.
agent.visualize()

# GAIA-style answer-format rules placed in front of every question so that the
# agent replies with only the bare answer value.
_GAIA_INSTRUCTIONS = """Instructions:
1. Your response must contain ONLY the answer to the question, nothing else
2. Do not repeat the question or any part of it
3. Do not include any explanations, reasoning, or context
4. Do not include source attribution or references
5. Do not use phrases like "The answer is" or "I found that"
6. Do not include any formatting, bullet points, or line breaks
7. If the answer is a number, return only the number
8. If the answer requires multiple items, separate them with commas
9. If the answer requires ordering, maintain the specified order
10. Use the most direct and succinct form possible"""


def main(task: str, *, max_steps: int = 5):
    """Answer a single question with the top-level agent.

    The question is prefixed with the GAIA-style formatting rules and run with
    ``reset=True``, so no memory carries over from earlier calls. The agent's
    raw result is logged and then passed through ``extract_final_answer``.

    Args:
        task: The question text to answer.
        max_steps: Maximum number of agent steps for this run (default 5).

    Returns:
        Whatever ``extract_final_answer`` produces from the agent's result.
    """
    # The blank line and explicit label keep the question from reading as a
    # continuation of rule 10.
    gaia_task = f"{_GAIA_INSTRUCTIONS}\n\nQuestion: {task}"
    result = agent.run(
        task=gaia_task,
        additional_args=None,
        images=None,
        max_steps=max_steps,
        reset=True,
        stream=False,
    )
    logger.info("Result: %s", result)
    return extract_final_answer(result)

if __name__ == "__main__":
    # Smoke-run the agent against the agents-course unit 4 scoring service:
    # fetch the question set and answer the first `question_limit` questions.
    # Answers are only logged; nothing is submitted back to the service.
    DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    question_limit = 1  # raise to evaluate more questions

    response = requests.get(questions_url, timeout=15)
    response.raise_for_status()
    questions_data = response.json()

    for question_data in questions_data[:question_limit]:
        question = question_data["question"]
        task_id = question_data.get("task_id")
        file_name = question_data.get("file_name")

        logger.info("Question: %s", question)
        if file_name:
            # NOTE(review): the attachment is only logged. It is never
            # downloaded or handed to the agent, so file-based questions are
            # answered without their input.
            logger.info("File Name: %s", file_name)

        try:
            final_answer = main(question)
        except Exception:
            # Top-level boundary: record the failure with its traceback and
            # move on to the next question instead of aborting the whole run.
            logger.exception("Agent failed on task %s", task_id)
        else:
            logger.info("Final Answer: %s", final_answer)
        logger.info("--------------------------------")