import importlib.resources  # explicit submodule import: importlib.resources.files() is used below
import logging
import os

import requests
import yaml
from dotenv import find_dotenv, load_dotenv
from litellm._logging import _disable_debugging
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from phoenix.otel import register

# from smolagents import CodeAgent, LiteLLMModel, LiteLLMRouterModel
from smolagents import CodeAgent, LiteLLMModel
from smolagents.default_tools import (
    DuckDuckGoSearchTool,
    VisitWebpageTool,
    WikipediaSearchTool,
)
from smolagents.monitoring import LogLevel

from agents.data_agent.agent import create_data_agent
from agents.media_agent.agent import create_media_agent
from agents.web_agent.agent import create_web_agent
from utils import extract_final_answer

_disable_debugging()

# Configure OpenTelemetry with Phoenix
register()
SmolagentsInstrumentor().instrument()
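# NOTE: register() defaults to a local Phoenix collector (typically http://localhost:6006)
# unless PHOENIX_COLLECTOR_ENDPOINT is set; check the arize-phoenix-otel docs for the
# exact environment variables your installed version reads.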

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

load_dotenv(find_dotenv())

API_BASE = os.getenv("API_BASE")
API_KEY = os.getenv("API_KEY")
MODEL_ID = os.getenv("MODEL_ID")

model = LiteLLMModel(
    api_base=API_BASE,
    api_key=API_KEY,
    model_id=MODEL_ID,
)
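# MODEL_ID is passed straight through to LiteLLM, so it should use LiteLLM's
# "<provider>/<model>" naming (e.g. "openai/gpt-4o-mini"); API_BASE / API_KEY are
# only required when the chosen provider needs them (e.g. a self-hosted endpoint).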

data_agent = create_data_agent(model)
media_agent = create_media_agent(model)
web_agent = create_web_agent(model)
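# These specialist agents are only wired in once the managed_agents argument of the
# CodeAgent below is uncommented; until then the manager agent answers everything itself.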

prompt_templates = yaml.safe_load(
    importlib.resources.files("smolagents.prompts")
    .joinpath("code_agent.yaml")
    .read_text()
)
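# Loads the default system/planning prompts bundled with smolagents for CodeAgent;
# editing a local copy of code_agent.yaml is the usual way to customize them.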

agent = CodeAgent(
    # add_base_tools=True,
    additional_authorized_imports=[
        "json",
        "pandas",
        "numpy",
        "re",
        # "requests"
        # "urllib.request",
    ],
    # max_steps=10,
    # managed_agents=[web_agent, data_agent, media_agent],
    model=model,
    prompt_templates=prompt_templates,
    tools=[
        DuckDuckGoSearchTool(max_results=3),
        VisitWebpageTool(max_output_length=1024),
        WikipediaSearchTool(),
    ],
    step_callbacks=None,
    verbosity_level=LogLevel.ERROR,
)

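# Print a tree view of the agent's tools (and any managed agents) as a startup sanity check.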
agent.visualize()


def main(task: str):
    """Run the manager agent on one GAIA question and return the extracted answer."""
    # Format the task with GAIA-style instructions
    gaia_task = f"""Instructions:
1. Your response must contain ONLY the answer to the question, nothing else
2. Do not repeat the question or any part of it
3. Do not include any explanations, reasoning, or context
4. Do not include source attribution or references
5. Do not use phrases like "The answer is" or "I found that"
6. Do not include any formatting, bullet points, or line breaks
7. If the answer is a number, return only the number
8. If the answer requires multiple items, separate them with commas
9. If the answer requires ordering, maintain the specified order
10. Use the most direct and succinct form possible

{task}"""
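    # The strict formatting rules above matter because GAIA answers are scored with
    # (quasi-)exact string matching, so any extra prose is counted as a wrong answer.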

    result = agent.run(
        additional_args=None,
        images=None,
        max_steps=5,
        reset=True,
        stream=False,
        task=gaia_task,
    )

    logger.info(f"Result: {result}")

    return extract_final_answer(result)


if __name__ == "__main__":
    DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"

    response = requests.get(questions_url, timeout=15)
    response.raise_for_status()
    questions_data = response.json()

    for question_data in questions_data[:1]:  # only the first question; drop the [:1] slice to run the full set
        file_name = question_data["file_name"]
        level = question_data["Level"]
        question = question_data["question"]
        task_id = question_data["task_id"]

        logger.info(f"Question: {question}")
        # logger.info(f"Level: {level}")
        if file_name:
            logger.info(f"File Name: {file_name}")
        # logger.info(f"Task ID: {task_id}")

        final_answer = main(question)
        logger.info(f"Final Answer: {final_answer}")
        logger.info("--------------------------------")