import json
import logging
import os
import re
import traceback
import uuid
from typing import AsyncGenerator

from aworld.agents.llm_agent import Agent
from aworld.config.conf import AgentConfig, TaskConfig
from aworld.core.task import Task
from aworld.output.base import Output
from aworld.output.ui.base import AworldUI
from aworld.output.ui.markdown_aworld_ui import MarkdownAworldUI
from aworld.runner import Runners

from .prompt import system_prompt
from .utils import (
    add_file_path,
    load_dataset_meta_dict,
    question_scorer,
)

logger = logging.getLogger(__name__)


class GaiaAgentRunner:
    """Run a GAIA benchmark task end-to-end: resolve the question by
    task_id, stream the agent's intermediate output as markdown, and
    score the final answer against the dataset's ground truth."""

    def __init__(
        self,
        llm_provider: str,
        llm_model_name: str,
        llm_base_url: str,
        llm_api_key: str,
        llm_temperature: float = 0.0,
        mcp_config: dict = None,
    ):
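        """Create the LLM config and the single GAIA "super agent".

        mcp_config is expected to follow the usual MCP client layout,
        i.e. a top-level "mcpServers" mapping of server name -> settings.
        """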
        # Guard against the mutable-default-argument pitfall: treat None as
        # "no MCP config".
        mcp_config = mcp_config or {}
        self.agent_config = AgentConfig(
            llm_provider=llm_provider,
            llm_model_name=llm_model_name,
            llm_api_key=llm_api_key,
            llm_base_url=llm_base_url,
            llm_temperature=llm_temperature,
        )

        self.super_agent = Agent(
            conf=self.agent_config,
            name="gaia_super_agent",
            system_prompt=system_prompt,
            mcp_config=mcp_config,
            # Materialize the keys view so downstream code gets a stable list.
            mcp_servers=list(mcp_config.get("mcpServers", {}).keys()),
        )

        # GAIA_DATASET_PATH overrides the bundled default of
        # <this file's directory>/GAIA/2023.
        self.gaia_dataset_path = os.path.abspath(
            os.getenv(
                "GAIA_DATASET_PATH",
                os.path.join(os.path.dirname(os.path.abspath(__file__)), "GAIA", "2023"),
            )
        )
        self.full_dataset = load_dataset_meta_dict(self.gaia_dataset_path)
        logger.info(
            f"Gaia Agent Runner initialized: super_agent={self.super_agent}, "
            f"agent_config={self.agent_config}, "
            f"gaia_dataset_path={self.gaia_dataset_path}, "
            f"full_dataset={len(self.full_dataset)}"
        )

    async def run(self, prompt: str):
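        """Stream one GAIA run as markdown chunks.

        `prompt` is either a JSON payload such as '{"task_id": "..."}' that
        selects a dataset question, or a free-form question used verbatim.
        """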
        yield "\n### GAIA Agent Start!"

        mcp_servers = "\n- ".join(self.super_agent.mcp_servers)
        yield f"\n```gaia_agent_status\n- {mcp_servers}\n```\n"

        question = None
        data_item = None
        task_id = None
        try:
            json_data = json.loads(prompt)
            task_id = json_data["task_id"]

            data_item = self.full_dataset[task_id]
            question = add_file_path(data_item, file_path=self.gaia_dataset_path)[
                "Question"
            ]
            yield f"\n```gaia_question\n{json.dumps(data_item, indent=2)}\n```\n"
        except Exception:
            # Not a dataset reference; fall through and treat the prompt as a
            # raw question below instead of silently swallowing the error.
            logger.debug(f"Prompt is not a GAIA task reference: {prompt}")

        if not question:
            logger.warning(
                "Could not find a GAIA question for the prompt; chatting with the prompt directly!"
            )
            yield f"\n{prompt}\n"
            question = prompt

        try:
            task = Task(
                # Scope the run id to the GAIA task id when one was resolved.
                id=f"{task_id}.{uuid.uuid1().hex}" if task_id else uuid.uuid1().hex,
                input=question,
                agent=self.super_agent,
                event_driven=False,
                conf=TaskConfig(max_steps=20),
            )

            last_output: Output = None
            rich_ui = MarkdownAworldUI()
            async for output in Runners.streamed_run_task(task).stream_events():
                logger.info(f"Gaia Agent Output: {output}")
                res = await AworldUI.parse_output(output, rich_ui)
                # parse_output may return one item or a list, and items may
                # themselves be async generators that have to be drained.
                for item in res if isinstance(res, list) else [res]:
                    if isinstance(item, AsyncGenerator):
                        async for sub_item in item:
                            yield sub_item
                    else:
                        yield item
                        last_output = item

            logger.info(f"Gaia Agent Last Output: {last_output}")

            if data_item and last_output:
                final_response = self._judge_answer(data_item, last_output)
                yield final_response
        except Exception:
            logger.error(f"Error processing {prompt}, error: {traceback.format_exc()}")

    def _judge_answer(self, data_item: dict, result: Output):
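        """Extract the <answer>...</answer> span from the final output, score
        it with question_scorer against the dataset's "Final answer", and
        return a markdown verdict block."""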
        # Coerce to text: in practice the streamed items are markdown strings,
        # and re.search requires str. DOTALL lets the answer span lines.
        answer = str(result)
        match = re.search(r"<answer>(.*?)</answer>", answer, re.DOTALL)
        if match:
            answer = match.group(1)
            logger.info(f"Agent answer: {answer}")
            logger.info(f"Correct answer: {data_item['Final answer']}")

            # Score once and reuse the verdict for both logging and the result.
            correct = question_scorer(answer, data_item["Final answer"])
            if correct:
                logger.info(f"Question {data_item['task_id']} Correct!")
            else:
                logger.info(f"Question {data_item['task_id']} Incorrect!")

            new_result = {
                "task_id": data_item["task_id"],
                "level": data_item["Level"],
                "question": data_item["Question"],
                "answer": data_item["Final answer"],
                "response": answer,
                "is_correct": correct,
            }
            return f"\n## Final Result: {'✅' if correct else '❌'}\n \n```gaia_result\n{json.dumps(new_result, indent=2)}\n```"
        else:
            return f"\n## Final Result:\n \n```gaia_result\n{json.dumps(answer, indent=2)}\n```"
if __name__ == "__main__":
    import argparse
    import asyncio
    from datetime import datetime

    output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")
    os.makedirs(output_dir, exist_ok=True)

    output_file = os.path.join(
        output_dir, f"output_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"
    )

    async def main():
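        """Parse --prompt, build a GaiaAgentRunner from the LLM_* environment
        variables, and append every streamed chunk to the output file."""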
        parser = argparse.ArgumentParser()
        parser.add_argument("--prompt", type=str, default="")
        args = parser.parse_args()

        try:
            prompt = args.prompt

            llm_provider = os.getenv("LLM_PROVIDER")
            llm_model_name = os.getenv("LLM_MODEL_NAME")
            llm_api_key = os.getenv("LLM_API_KEY")
            llm_base_url = os.getenv("LLM_BASE_URL")
            # os.getenv returns a string, so coerce the temperature to float.
            llm_temperature = float(os.getenv("LLM_TEMPERATURE", "0.0"))

            def send_output(output):
                # Append each streamed chunk to the timestamped markdown file.
                with open(output_file, "a") as f:
                    f.write(f"{output}\n")

            async for i in GaiaAgentRunner(
                llm_provider=llm_provider,
                llm_model_name=llm_model_name,
                llm_base_url=llm_base_url,
                llm_api_key=llm_api_key,
                llm_temperature=llm_temperature,
            ).run(prompt):
                send_output(i)
        except Exception:
            logger.error(
                f"Error processing {args.prompt}, error: {traceback.format_exc()}"
            )

    asyncio.run(main())