pateas's picture
Update agent.py: refine final answer instructions and adjust BasicAgent model parameters
dc70c2d unverified
import base64
import logging
import os
from io import BytesIO
from typing import Any
from smolagents import (
CodeAgent,
DuckDuckGoSearchTool,
OpenAIServerModel,
VisitWebpageTool,
WikipediaSearchTool,
tool,
)
# Instructions handed to the agent: they describe how extra context is labelled
# and pin down the exact "FINAL ANSWER: ..." template the reply must end with.
system_prompt = """You are an AI Agent that is tasked to answer questions in a concise and accurate manner.
I will ask you a question and provide you with additional context if available.
Context can be in the form of Data(data), Code(code), Audio(audio), or Images(image_url).
Context is provided by specifying the content type followed by the content itself.
For example: code: print("Hello World") or Data: [1, 2, 3, 4, 5] or audio: [base64 encoded audio] or image_url: [base64 encoded image].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
DO NOT use formatting such as bold, italics, or code blocks in your final answer.
DO NOT use sources, references, or abbreviations in your final answer.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
If you are asked for a specific number format, follow the instructions carefully.
If you are asked for a number only answer with the number itself, without any additional text or formatting.
If you are asked for a string only answer with the string itself, without any additional text or formatting.
If you are asked for a list only answer with the list itself, without any additional text or formatting.
Think step by step. Report your thoughts.
Finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER].
For example, if the question is "What is the capital of France?", you should answer:
FINAL ANSWER: Paris
If the question is "What is 2 + 2?", you should answer:
FINAL ANSWER: 4
If the question is "What is 1 divided by 2, answer with 2 digits after the decimal point?", you should answer:
FINAL ANSWER: 0.50
If the question is "What is 10 * 10 with four digits after the decimal point?", you should answer:
FINAL ANSWER: 100.0000
"""
# def is_correct_format(answer: str, _) -> bool:
# """Check if the answer contains a final answer in the correct format.
# Args:
# answer: The answer to check.
# Returns:
# True if the answer contains a final answer, False otherwise.
# This ensures the final output is in the correct format.
# """
# return (
# "ANSWER:" in answer
# or "FINAL ANSWER:" in answer
# or "Answer:" in answer
# or "Final Answer:" in answer
# or "answer:" in answer
# or "final answer:" in answer
# or "answer:" in answer.lower()
# or "final answer:" in answer.lower()
# )
@tool
def wikipedia_suggested_page(query: str) -> str:
    """Ask Wikipedia for a suggested page title matching the query.

    Args:
        query: The search query. The query should be coarse and not provide too many details.
            E.g. "Python programming" or "Artificial Intelligence".

    Returns:
        The suggested page title, a short message when Wikipedia has no
        suggestion for the query, or an error message if the lookup failed.
    """
    from wikipedia import suggest

    try:
        suggestion = suggest(query)
    except Exception as e:
        # Tool boundary: report the failure as text so the agent can react to
        # it instead of aborting the whole run.
        logging.error(f"Error fetching Wikipedia suggestions for '{query}': {e}")
        return f"Error fetching suggestions: {e}"

    # `suggest` yields None when Wikipedia has nothing to propose; always hand
    # a string back, as the declared return type promises.
    if not suggestion:
        return f"No Wikipedia suggestion found for '{query}'."
    return suggestion
@tool
def wikipedia_page(title: str) -> str:
    """Fetch the content of the Wikipedia page matching a title.

    Args:
        title: The title of the Wikipedia page to search for.

    Returns:
        The content of the Wikipedia page, or an error message if the page
        could not be fetched.
    """
    import wikipedia

    try:
        article = wikipedia.page(title, auto_suggest=True)
        return article.content
    except Exception as err:
        # Any failure is logged and reported back as text rather than raised.
        logging.error(f"Error fetching Wikipedia page for '{title}': {err}")
        return f"Error fetching page: {err}"
class BasicAgent:
def __init__(self):
model = OpenAIServerModel(
model_id="gpt-4o-mini",
api_key=os.getenv("OPENAI_API_KEY"),
temperature=0.0,
)
search = DuckDuckGoSearchTool(max_results=5)
# speech_to_text = SpeechToTextTool()
visitor = VisitWebpageTool(max_output_length=4000)
wiki_search = WikipediaSearchTool()
self.agent = CodeAgent(
max_steps=10,
verbosity_level=0,
tools=[
search,
# speech_to_text,
visitor,
wiki_search,
wikipedia_suggested_page,
wikipedia_page,
],
model=model,
instructions=system_prompt,
additional_authorized_imports=["pandas", "numpy"],
use_structured_outputs_internally=True,
add_base_tools=True,
)
logging.info(
f"System prompt set for BasicAgent: {self.agent.memory.system_prompt}"
)
def __call__(self, question: str, content, content_type) -> Any:
match content_type:
case "xlsx":
additional_args = {"data": content}
case "py":
additional_args = {"code": content}
case "audio":
additional_args = {"audio": content}
case "png":
buffer = BytesIO()
content.save(buffer, format="PNG")
buffer.seek(0)
image_content = (
"data:image/png;base64,"
+ base64.b64encode(buffer.getvalue()).decode("utf-8")
)
additional_args = {"image_url": image_content}
case _:
additional_args = None
response = self.agent.run(
question,
additional_args=additional_args,
images=[content] if content_type == "png" else None,
reset=True,
)
return response
@staticmethod
def formatting(answer: str) -> str:
"""Extract the final answer from the response."""
if "FINAL ANSWER:" in answer:
answer = answer.split("FINAL ANSWER:")[-1].strip()
if "ANSWER:" in answer:
answer = answer.split("ANSWER:")[-1].strip()
if "Answer:" in answer:
answer = answer.split("Answer:")[-1].strip()
if "Final Answer:" in answer:
answer = answer.split("Final Answer:")[-1].strip()
if "answer:" in answer.lower():
answer = answer.split("answer:")[-1].strip()
if "final answer:" in answer.lower():
answer = answer.split("final answer:")[-1].strip()
if "answer is:" in answer.lower():
answer = answer.split("answer is:")[-1].strip()
if "is:" in answer.lower():
answer = answer.split("is:")[-1].strip()
if "**" in answer:
answer = answer.split("**")[-1].strip().replace("**", "")
if "```" in answer:
answer = answer.split("```")[-1].strip().replace("```", "")
if "```python" in answer:
answer = answer.split("```python")[-1].strip().replace("```", "")
if "```json" in answer:
answer = answer.split("```json")[-1].strip().replace("```", "")
if "```yaml" in answer:
answer = answer.split("```yaml")[-1].strip().replace("```", "")
if "```txt" in answer:
answer = answer.split("```txt")[-1].strip().replace("```", "")
answer = answer.capitalize()
answer = answer.replace('"', '').strip()
answer = answer.replace("'", "").strip()
answer = answer.replace("[", "").replace("]", "").strip()
return answer.strip() # Fallback to return the whole answer if no specific format found