File size: 3,285 Bytes
912f746
 
 
 
 
fb8728b
 
 
 
 
 
 
912f746
 
fb8728b
912f746
 
 
 
0866aba
 
912f746
 
 
 
 
 
fb8728b
 
 
 
 
 
912f746
fb8728b
912f746
 
fb8728b
 
 
 
 
 
 
 
 
 
 
 
912f746
 
4a8d3f6
 
 
912f746
 
 
 
fb8728b
912f746
 
 
 
 
 
 
 
 
fb8728b
912f746
fb8728b
 
912f746
fb8728b
 
912f746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os

import PIL.Image
from dotenv import load_dotenv
from loguru import logger
from smolagents import (
    AzureOpenAIServerModel,
    CodeAgent,
    GoogleSearchTool,
    PythonInterpreterTool,
    VisitWebpageTool,
)

from src.file_handler.parse import parse_file
from src.tools import reverse_question

load_dotenv()


class Agent:
    """Question-answering agent built on a smolagents ``CodeAgent``.

    The agent is backed by an Azure OpenAI model and equipped with web search,
    web-page visiting, a Python interpreter and the project's
    ``reverse_question`` tool. Calling the instance answers one question,
    optionally enriched with an attached file fetched via ``parse_file``.
    """

    def __init__(self):
        """Build the model, tools and prompt template.

        Reads ``AZURE_OPENAI_MODEL_ID``, ``AZURE_OPENAI_ENDPOINT``,
        ``AZURE_OPENAI_API_KEY`` and ``OPENAI_API_VERSION`` from the
        environment (a ``.env`` file is loaded at import time).
        """
        model = AzureOpenAIServerModel(
            model_id=os.getenv("AZURE_OPENAI_MODEL_ID"),
            azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
            api_key=os.getenv("AZURE_OPENAI_API_KEY"),
            api_version=os.getenv("OPENAI_API_VERSION"),
        )
        tools = [
            GoogleSearchTool(provider="serper"),
            VisitWebpageTool(),
            PythonInterpreterTool(),
            reverse_question,
        ]
        self.agent = CodeAgent(
            tools=tools,
            model=model,
        )
        # Template with two placeholders, {question} and {content}; both must
        # be supplied together (see _build_prompt).
        self.user_prompt = """
        I will ask you a question.
        Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
        YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
        If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
        If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
        If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.

        Question: {question}

        Attached content: {content}
        """
        logger.info("BasicAgent initialized.")

    def _build_prompt(self, question: str, content: object = "") -> str:
        """Fill the prompt template with the question and textual attachment.

        Args:
            question: The user's question, inserted verbatim.
            content: Textual attachment to append; ``""`` when there is none.

        Returns:
            The fully formatted prompt.
        """
        # Format exactly once with every placeholder: formatting with only
        # ``question`` raises KeyError('content'), and a second ``.format``
        # pass would re-interpret any braces contained in the question.
        return self.user_prompt.format(question=question, content=content)

    def __call__(
        self, question: str, task_id: str, file_name: str, api_url: str
    ) -> str:
        """Answer ``question``, optionally using the attached file.

        Args:
            question: The question text.
            task_id: Task identifier, forwarded to ``parse_file``.
            file_name: Name of the attached file; falsy when there is none.
            api_url: Base URL forwarded to ``parse_file`` to fetch the file.

        Returns:
            The agent's answer with the ``FINAL ANSWER:`` marker removed.
        """
        logger.info(
            f"Agent received question (first 50 chars): {question[:50]}..."
        )
        images = None
        text_content = ""

        if file_name:
            content = parse_file(task_id, file_name, api_url)
            if isinstance(content, PIL.Image.Image):
                # Images go to the model as vision input, not prompt text.
                images = [content]
            elif content:
                text_content = content

        prompt = self._build_prompt(question, text_content)
        if text_content:
            logger.info(f"Question with content: {prompt}")

        answer = self.agent.run(prompt, images=images)
        # The agent's final answer is not guaranteed to be a str (it may be
        # a number or list), so coerce before string post-processing.
        answer = str(answer).replace("FINAL ANSWER:", "").strip()
        logger.info(f"Agent returning answer: {answer}")
        return answer


if __name__ == "__main__":
    # Manual smoke test: fetch one random question from the scoring API and
    # run the agent on it.
    import requests

    api_url = "https://agents-course-unit4-scoring.hf.space"
    question_url = f"{api_url}/random-question"

    # Bound the wait and surface HTTP errors directly instead of hanging or
    # failing later with a confusing JSON/KeyError.
    response = requests.get(question_url, timeout=30)
    response.raise_for_status()
    data = response.json()
    agent = Agent()

    task_id = data["task_id"]
    question = data["question"]
    file_name = data["file_name"]
    logger.info(
        f"Task ID: {task_id}\nQuestion: {question}\nFile Name: {file_name}\n\n"
    )

    # Agent.__call__ needs task_id and api_url as well so that it can fetch
    # the attached file; passing only (question, file_name) raises TypeError.
    answer = agent(question, task_id, file_name, api_url)