File size: 4,132 Bytes
e836bd4 188a166 bd03e7f 188a166 09253eb 188a166 09253eb 188a166 09253eb 188a166 4695b90 188a166 09253eb 188a166 4695b90 188a166 09253eb 188a166 4695b90 09253eb 188a166 09253eb 188a166 09253eb 188a166 09253eb 188a166 09253eb 188a166 09253eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import os
import asyncio
import re
from typing import Any
from llama_index.llms.openai import OpenAI
from llama_index.core.agent.react import ReActAgent
from llama_index.core.tools import FunctionTool
# Correct import for LlamaIndex >= 0.10
from llama_index.tools.duckduckgo_search import DuckDuckGoSearchTool
# Simple tool: evaluate a Python expression for math/code questions.

# Harmless builtins exposed to evaluated expressions. An empty __builtins__ mapping
# would make even sum(...), len(...) or range(...) raise NameError, which defeats the
# purpose of a math/code helper. Import machinery, open(), getattr() etc. are left out.
_SAFE_EVAL_BUILTINS = {
    "abs": abs, "all": all, "any": any, "bin": bin, "bool": bool, "chr": chr,
    "dict": dict, "divmod": divmod, "enumerate": enumerate, "filter": filter,
    "float": float, "hex": hex, "int": int, "len": len, "list": list, "map": map,
    "max": max, "min": min, "oct": oct, "ord": ord, "pow": pow, "range": range,
    "reversed": reversed, "round": round, "set": set, "sorted": sorted,
    "str": str, "sum": sum, "tuple": tuple, "zip": zip,
}


def eval_python_code(code: str) -> str:
    """Evaluate a single Python expression and return its result as a string.

    The expression sees a small whitelist of builtins plus the ``math`` module.

    Args:
        code: A Python *expression* (``eval`` does not accept statements).

    Returns:
        ``str(result)`` on success, or ``"ERROR: <message>"`` if evaluation raises.
    """
    import math  # local import keeps this tool self-contained

    # SECURITY: eval() with restricted builtins is NOT a sandbox -- attribute walks such
    # as ().__class__.__mro__ still reach arbitrary objects. The expression is produced
    # by the LLM (whose context includes web-search results), so treat it as untrusted;
    # move this into an isolated subprocess/sandbox before exposing the agent publicly.
    # A fresh copy of the whitelist is passed each call so an expression like
    # __builtins__.clear() cannot affect later evaluations.
    namespace = {"__builtins__": dict(_SAFE_EVAL_BUILTINS), "math": math}
    try:
        return str(eval(code, namespace))
    except Exception as e:
        return f"ERROR: {e}"
# Strict output formatting for GAIA

# Refusal / apology boilerplate to cut from an answer (case-insensitive). The leading \b
# and the \b after "m" keep it from firing inside ordinary words (e.g. "Jim Smith",
# "minimum", "iMac"). Both ASCII and typographic apostrophes are accepted.
_GAIA_REFUSAL_RE = re.compile(
    r"(?i)\bi(?:['\u2019]?m\b| cannot| can['\u2019]t| unable to| apologize| not available|process the file).*"
)

# Modifiers that a comma split may have separated from their noun ("sweet, potatoes").
_GAIA_INGREDIENT_MODIFIERS = frozenset({'sweet', 'green', 'lemon', 'ripe', 'whole', 'fresh'})


def _gaia_first_word(text: str) -> str:
    """Return the first whitespace-delimited word of *text* without trailing , ; or :."""
    return text.split()[0].rstrip(',;:')


def _gaia_last_word(text: str) -> str:
    """Return the last whitespace-delimited word of *text* without trailing , ; or :."""
    return text.split()[-1].rstrip(',;:')


def _gaia_merge_modifiers(items: list) -> list:
    """Join each known modifier to the item that follows it, then dedupe and sort."""
    merged = []
    skip = False
    for i, item in enumerate(items):
        if skip:
            skip = False
            continue
        if i + 1 < len(items) and item in _GAIA_INGREDIENT_MODIFIERS:
            merged.append(f"{item} {items[i + 1]}")
            skip = True
        else:
            merged.append(item)
    return sorted(set(merged))


def format_gaia_answer(answer: str, question: str = "") -> str:
    """Normalise a model answer into the terse form GAIA expects.

    Always removes "final answer" markers, refusal/apology text, one pair of wrapping
    double quotes or square brackets, and a trailing period. When *question* is given,
    keyword heuristics on the question text then pick the answer shape:

    * counting / number questions -> first number in the answer, commas removed;
    * "first name" / "city" -> first word; "surname" -> last word;
    * IOC country code / award number / NASA -> first run of 3+ uppercase letters/digits;
    * list / comma-separated / page-number questions -> cleaned ", "-joined list.

    Args:
        answer: Raw answer text; falsy values yield "".
        question: The original question; heuristics are skipped when empty.

    Returns:
        The cleaned answer, or "" when nothing usable remains.
    """
    if not answer:
        return ""
    answer = re.sub(r'(?i)final answer:?\s*', '', answer).strip()
    answer = _GAIA_REFUSAL_RE.sub('', answer).strip()
    if answer.startswith('"') and answer.endswith('"'):
        answer = answer[1:-1].strip()
    if answer.startswith('[') and answer.endswith(']'):
        answer = answer[1:-1].strip()
    # Drop a trailing period up front so every branch below sees the same text.
    answer = re.sub(r'\.$', '', answer).strip()
    # Cleanup can remove everything (e.g. a pure refusal); the word-picking branches
    # below would otherwise raise IndexError on an empty split().
    if not answer:
        return ""
    if question:
        q = question.lower()
        if re.search(r'how many|number of|at bats|total sales|albums|output.*python|highest number', question, re.I):
            num = re.search(r'(\$?\d[\d,\.]*)', answer)
            if num:
                # The character class also swallows a sentence period ("3. ..."), drop it.
                return num.group(1).replace(',', '').rstrip('.')
        if 'first name' in q:
            return _gaia_first_word(answer)
        if 'surname' in q:
            return _gaia_last_word(answer)
        # Word boundary so "capacity", "electricity", ... are not read as a city question.
        if re.search(r'\bcity\b', q):
            return _gaia_first_word(answer)
        if re.search(r'IOC country code|award number|NASA', question, re.I):
            code = re.search(r'[A-Z0-9]{3,}', answer)
            if code:
                return code.group(0)
        if re.search(r'\blist\b|comma.*separated|page numbers', question, re.I):
            if 'page numbers' in q:
                nums = sorted(int(n) for n in re.findall(r'\d+', answer))
                return ', '.join(str(n) for n in nums)
            # Strip whitespace as well as quotes/commas/periods; otherwise " bananas" keeps
            # its leading space (double-spaced output, modifier matching fails).
            items = [x.strip().strip('",.').strip().lower() for x in re.split(r'[,\n]', answer)]
            items = [x for x in items if x]
            if 'ingredient' in q or 'vegetable' in q:
                items = _gaia_merge_modifiers(items)
            return ', '.join(items)
    return answer.strip().rstrip('.').strip()
# LLM setup
# Module-level singletons: built once at import time and shared by every question.
# os.environ.get returns None when OPENAI_API_KEY is unset.
llm = OpenAI(model="gpt-4o", api_key=os.environ.get("OPENAI_API_KEY"))
# Tool registry
# Tools the ReAct agent may call while answering.
# NOTE(review): the published llama-index DuckDuckGo integration appears to expose a
# ToolSpec (DuckDuckGoSearchToolSpec, used via .to_tool_list()) rather than a ready-made
# tool object -- confirm DuckDuckGoSearchTool() yields a valid tool in the installed version.
tools = [
    DuckDuckGoSearchTool(),
    # The explicit name/description below are what the agent sees for this tool.
    FunctionTool.from_defaults(
        eval_python_code,
        name="python_eval",
        description="Evaluate simple Python code and return result as string. Use for math or code output."
    ),
    # The formatter is offered to the agent as an optional tool; nothing else in this
    # module is shown applying it to the agent's final reply.
    FunctionTool.from_defaults(
        format_gaia_answer,
        name="format_gaia_answer",
        description="Postprocess and enforce strict GAIA format on answers given a question."
    ),
]
# Main agent
# One agent instance is reused for every call made through answer_question.
# NOTE(review): verify that ReActAgent.from_tools honours `system_prompt` in the installed
# llama_index version; legacy ReAct agents take extra instructions via `context`, and an
# unsupported keyword may be ignored or rejected -- confirm.
agent = ReActAgent.from_tools(
    tools=tools,
    llm=llm,
    system_prompt="You are a helpful GAIA benchmark agent. For every question, use the best tools available and always return only the final answer in the strict GAIA-required format—never explain, never apologize.",
    verbose=False
)
# Async entrypoint
async def answer_question(question: str, task_id: str = None, file_path: str = None) -> str:
    """Answer one question with the shared module-level agent.

    Args:
        question: The question text sent to the agent.
        task_id: Accepted for interface compatibility; currently unused.
        file_path: Accepted for interface compatibility; currently unused, so any
            attachment is NOT made available to the agent.

    Returns:
        The agent's reply text (``result.response``).
    """
    # The same agent object serves every question, and its chat history persists between
    # chat calls. Without clearing it, earlier questions/answers leak into this one and
    # the prompt grows with every task. getattr keeps this tolerant of agent classes
    # that lack reset().
    reset = getattr(agent, "reset", None)
    if callable(reset):
        reset()
    result = await agent.achat(question)
    return result.response
# Synchronous wrapper
def answer_question_sync(question: str, task_id: str = None, file_path: str = None) -> str:
    """Blocking wrapper around ``answer_question``.

    ``asyncio.run`` raises RuntimeError when the calling thread already has a running
    event loop (e.g. Jupyter or an async web handler). In that case the coroutine is run
    on its own event loop in a single worker thread, and this call blocks until it is done.

    Args:
        question: The question text.
        task_id: Forwarded unchanged to ``answer_question``.
        file_path: Forwarded unchanged to ``answer_question``.

    Returns:
        The answer text. Exceptions from ``answer_question`` propagate to the caller.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # Normal case: no event loop in this thread, so asyncio.run is safe.
        return asyncio.run(answer_question(question, task_id, file_path))

    from concurrent.futures import ThreadPoolExecutor

    with ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(
            lambda: asyncio.run(answer_question(question, task_id, file_path))
        )
        return future.result()
# For compatibility with app.py (GAIAAgent class)
class GaiaAgent:
    """Callable object that delegates each question to ``answer_question_sync``."""

    def __call__(self, question: str, task_id: str = None, file_path: str = None) -> str:
        """Answer *question* synchronously, forwarding *task_id* and *file_path* as-is."""
        reply = answer_question_sync(question, task_id=task_id, file_path=file_path)
        return reply