pateas's picture
Update agent.py: refine final answer instructions and adjust BasicAgent model parameters
dc70c2d unverified
raw
history blame
8.05 kB
import base64
import logging
import os
from io import BytesIO
from typing import Any
from smolagents import (
CodeAgent,
DuckDuckGoSearchTool,
OpenAIServerModel,
VisitWebpageTool,
WikipediaSearchTool,
tool,
)
# Instructions passed to the CodeAgent as ``instructions`` in BasicAgent.__init__;
# they describe the attachment variables and the required final-answer format.
system_prompt = """You are an AI Agent that is tasked to answer questions in a concise and accurate manner.
I will ask you a question and provide you with additional context if available.
Context can be in the form of Data(data), Code(code), Audio(audio), or Images(image_url).
Context is provided by specifying the content type followed by the content itself.
For example: code: print("Hello World") or Data: [1, 2, 3, 4, 5] or audio: [base64 encoded audio] or image_url: [base64 encoded image].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
DO NOT use formatting such as bold, italics, or code blocks in your final answer.
DO NOT use sources, references, or abbreviations in your final answer.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending on whether the element to be put in the list is a number or a string.
If you are asked for a specific number format, follow the instructions carefully.
If you are asked for a number only answer with the number itself, without any additional text or formatting.
If you are asked for a string only answer with the string itself, without any additional text or formatting.
If you are asked for a list only answer with the list itself, without any additional text or formatting.
Think step by step. Report your thoughts.
Finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER].
For example, if the question is "What is the capital of France?", you should answer:
FINAL ANSWER: Paris
If the question is "What is 2 + 2?", you should answer:
FINAL ANSWER: 4
If the question is "What is 1 divided by 2, answer with 2 digits after the decimal point?", you should answer:
FINAL ANSWER: 0.50
If the question is "What is 10 * 10 with four digits after the decimal point?", you should answer:
FINAL ANSWER: 100.0000
"""
# def is_correct_format(answer: str, _) -> bool:
# """Check if the answer contains a final answer in the correct format.
# Args:
# answer: The answer to check.
# Returns:
# True if the answer contains a final answer, False otherwise.
# This ensures the final output is in the correct format.
# """
# return (
# "ANSWER:" in answer
# or "FINAL ANSWER:" in answer
# or "Answer:" in answer
# or "Final Answer:" in answer
# or "answer:" in answer
# or "final answer:" in answer
# or "answer:" in answer.lower()
# or "final answer:" in answer.lower()
# )
@tool
def wikipedia_suggested_page(query: str) -> str:
    """Search Wikipedia for suggested pages based on the query.

    Args:
        query: The search query. The query should be coarse and not provide too many details.
            E.g. "Python programming" or "Artificial Intelligence".

    Returns:
        A list of suggested page titles, one title per line.
    """
    # Imported lazily so the tool stays self-contained.
    from wikipedia import search

    try:
        # `wikipedia.suggest` only returns a single spelling suggestion (or
        # None); `search` returns the list of matching page titles promised
        # by this tool's description.
        titles = search(query)
    except Exception as e:
        # Best-effort: report the failure to the agent instead of raising.
        logging.error(f"Error fetching Wikipedia suggestions for '{query}': {e}")
        return f"Error fetching suggestions: {e}"
    if not titles:
        return f"No Wikipedia pages found for '{query}'."
    return "\n".join(titles)
@tool
def wikipedia_page(title: str) -> str:
    """Search Wikipedia for a page based on the title.

    Args:
        title: The title of the Wikipedia page to search for.

    Returns:
        The content of the Wikipedia page.
    """
    # Imported lazily so the tool stays self-contained.
    from wikipedia import page
    from wikipedia.exceptions import PageError

    try:
        try:
            # Prefer the exact title: with auto_suggest=True the library may
            # replace a valid title with a different "suggested" article.
            return page(title, auto_suggest=False).content
        except PageError:
            # No page with that exact title; fall back to Wikipedia's suggestion.
            return page(title, auto_suggest=True).content
    except Exception as e:
        # Best-effort: report the failure (including disambiguation options,
        # which are part of the exception text) to the agent instead of raising.
        logging.error(f"Error fetching Wikipedia page for '{title}': {e}")
        return f"Error fetching page: {e}"
class BasicAgent:
def __init__(self):
model = OpenAIServerModel(
model_id="gpt-4o-mini",
api_key=os.getenv("OPENAI_API_KEY"),
temperature=0.0,
)
search = DuckDuckGoSearchTool(max_results=5)
# speech_to_text = SpeechToTextTool()
visitor = VisitWebpageTool(max_output_length=4000)
wiki_search = WikipediaSearchTool()
self.agent = CodeAgent(
max_steps=10,
verbosity_level=0,
tools=[
search,
# speech_to_text,
visitor,
wiki_search,
wikipedia_suggested_page,
wikipedia_page,
],
model=model,
instructions=system_prompt,
additional_authorized_imports=["pandas", "numpy"],
use_structured_outputs_internally=True,
add_base_tools=True,
)
logging.info(
f"System prompt set for BasicAgent: {self.agent.memory.system_prompt}"
)
def __call__(self, question: str, content, content_type) -> Any:
match content_type:
case "xlsx":
additional_args = {"data": content}
case "py":
additional_args = {"code": content}
case "audio":
additional_args = {"audio": content}
case "png":
buffer = BytesIO()
content.save(buffer, format="PNG")
buffer.seek(0)
image_content = (
"data:image/png;base64,"
+ base64.b64encode(buffer.getvalue()).decode("utf-8")
)
additional_args = {"image_url": image_content}
case _:
additional_args = None
response = self.agent.run(
question,
additional_args=additional_args,
images=[content] if content_type == "png" else None,
reset=True,
)
return response
@staticmethod
def formatting(answer: str) -> str:
"""Extract the final answer from the response."""
if "FINAL ANSWER:" in answer:
answer = answer.split("FINAL ANSWER:")[-1].strip()
if "ANSWER:" in answer:
answer = answer.split("ANSWER:")[-1].strip()
if "Answer:" in answer:
answer = answer.split("Answer:")[-1].strip()
if "Final Answer:" in answer:
answer = answer.split("Final Answer:")[-1].strip()
if "answer:" in answer.lower():
answer = answer.split("answer:")[-1].strip()
if "final answer:" in answer.lower():
answer = answer.split("final answer:")[-1].strip()
if "answer is:" in answer.lower():
answer = answer.split("answer is:")[-1].strip()
if "is:" in answer.lower():
answer = answer.split("is:")[-1].strip()
if "**" in answer:
answer = answer.split("**")[-1].strip().replace("**", "")
if "```" in answer:
answer = answer.split("```")[-1].strip().replace("```", "")
if "```python" in answer:
answer = answer.split("```python")[-1].strip().replace("```", "")
if "```json" in answer:
answer = answer.split("```json")[-1].strip().replace("```", "")
if "```yaml" in answer:
answer = answer.split("```yaml")[-1].strip().replace("```", "")
if "```txt" in answer:
answer = answer.split("```txt")[-1].strip().replace("```", "")
answer = answer.capitalize()
answer = answer.replace('"', '').strip()
answer = answer.replace("'", "").strip()
answer = answer.replace("[", "").replace("]", "").strip()
return answer.strip() # Fallback to return the whole answer if no specific format found