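"""Evaluate an agent team on the GAIA benchmark.

A manager CodeAgent delegates web browsing to a search ToolCallingAgent,
answers questions concurrently, and appends each result to a JSONL file.
"""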
import argparse
import json
import os
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path
from typing import Any

import datasets
import pandas as pd
from dotenv import load_dotenv
from huggingface_hub import login, snapshot_download
from scripts.reformulator import prepare_response
from scripts.run_agents import (
    get_single_file_description,
    get_zip_description,
)
from scripts.text_inspector_tool import TextInspectorTool
from scripts.text_web_browser import (
    ArchiveSearchTool,
    FinderTool,
    FindNextTool,
    PageDownTool,
    PageUpTool,
    SimpleTextBrowser,
    VisitTool,
)
from scripts.visual_qa import visualizer
from tqdm import tqdm

from smolagents import (
    CodeAgent,
    GoogleSearchTool,
    LiteLLMModel,
    Model,
    ToolCallingAgent,
)

load_dotenv(override=True)
login(os.getenv("HF_TOKEN"))

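# Serializes appends to the shared answers file across worker threads.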
append_answer_lock = threading.Lock()


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--concurrency", type=int, default=8)
    parser.add_argument("--model-id", type=str, default="o1")
    parser.add_argument("--run-name", type=str, required=True)
    parser.add_argument("--set-to-run", type=str, default="validation")
    # store_true instead of type=bool: argparse's type=bool treats any
    # non-empty string (even "False") as True.
    parser.add_argument("--use-open-models", action="store_true")
    parser.add_argument("--use-raw-dataset", action="store_true")
    return parser.parse_args()


print("Make sure you have deactivated any VPN like Tailscale, else some URLs will be blocked!")

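# Map smolagents' "tool-call"/"tool-response" messages to standard chat roles
# (presumably for providers that reject non-standard roles).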
custom_role_conversions = {"tool-call": "assistant", "tool-response": "user"}

user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"

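# Settings for the text-based browser shared by all web tools below.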
BROWSER_CONFIG = {
    "viewport_size": 1024 * 5,
    "downloads_folder": "downloads_folder",
    "request_kwargs": {
        "headers": {"User-Agent": user_agent},
        "timeout": 300,
    },
    "serpapi_key": os.getenv("SERPAPI_API_KEY"),
}

os.makedirs(f"./{BROWSER_CONFIG['downloads_folder']}", exist_ok=True)


def create_agent_team(model: Model):
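    """Build a manager CodeAgent that delegates web research to a managed search agent."""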
    text_limit = 100000
    ti_tool = TextInspectorTool(model, text_limit)

    browser = SimpleTextBrowser(**BROWSER_CONFIG)

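    # Tools the search agent uses to query the web, navigate pages, and
    # inspect non-HTML files as text.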
    WEB_TOOLS = [
        GoogleSearchTool(provider="serper"),
        VisitTool(browser),
        PageUpTool(browser),
        PageDownTool(browser),
        FinderTool(browser),
        FindNextTool(browser),
        ArchiveSearchTool(browser),
        TextInspectorTool(model, text_limit),
    ]

    text_webbrowser_agent = ToolCallingAgent(
        model=model,
        tools=WEB_TOOLS,
        max_steps=20,
        verbosity_level=2,
        planning_interval=4,
        name="search_agent",
        description="""A team member that will search the internet to answer your question.
Ask him all your questions that require browsing the web.
Provide him with as much context as possible, in particular if you need to search within a specific timeframe!
And don't hesitate to provide him with a complex search task, like finding a difference between two webpages.
Your request must be a real sentence, not a google search! Like "Find me this information (...)" rather than a few keywords.
""",
        provide_run_summary=True,
    )
    text_webbrowser_agent.prompt_templates["managed_agent"]["task"] += """You can navigate to .txt online files.
If a non-html page is in another format, especially .pdf or a Youtube video, use the tool 'inspect_file_as_text' to inspect it.
Additionally, if after some searching you find out that you need more information to answer the question, you can call `final_answer` with your clarification request as its argument to ask for more information."""

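    # The manager can run code, inspect images and documents directly, and
    # delegate web research to the search agent above.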
    manager_agent = CodeAgent(
        model=model,
        tools=[visualizer, ti_tool],
        max_steps=12,
        verbosity_level=2,
        additional_authorized_imports=["*"],
        planning_interval=4,
        managed_agents=[text_webbrowser_agent],
    )
    return manager_agent


def load_gaia_dataset(use_raw_dataset: bool, set_to_run: str) -> datasets.Dataset:
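    """Download the GAIA dataset locally if needed and load the requested split."""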
    if not os.path.exists("data/gaia"):
        if use_raw_dataset:
            snapshot_download(
                repo_id="gaia-benchmark/GAIA",
                repo_type="dataset",
                local_dir="data/gaia",
                ignore_patterns=[".gitattributes", "README.md"],
            )
        else:
            snapshot_download(
                repo_id="smolagents/GAIA-annotated",
                repo_type="dataset",
                local_dir="data/gaia",
                ignore_patterns=[".gitattributes", "README.md"],
            )

    def preprocess_file_paths(row):
        if len(row["file_name"]) > 0:
            row["file_name"] = f"data/gaia/{set_to_run}/" + row["file_name"]
        return row

    eval_ds = datasets.load_dataset(
        "data/gaia/GAIA.py",
        name="2023_all",
        split=set_to_run,
    )

    eval_ds = eval_ds.rename_columns({"Question": "question", "Final answer": "true_answer", "Level": "task"})
    eval_ds = eval_ds.map(preprocess_file_paths)
    return eval_ds


def append_answer(entry: dict, jsonl_file: str) -> None:
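    """Append one result entry to the JSONL answers file, serialized by a lock."""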
    jsonl_path = Path(jsonl_file)
    jsonl_path.parent.mkdir(parents=True, exist_ok=True)
    with append_answer_lock, open(jsonl_file, "a", encoding="utf-8") as fp:
        fp.write(json.dumps(entry) + "\n")
    assert jsonl_path.exists(), "File not found!"
    print("Answer exported to file:", jsonl_path.resolve())


def answer_single_question(
    example: dict, model_id: str, answers_file: str, visual_inspection_tool
) -> None:
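    """Run the agent team on one GAIA example and append the annotated result."""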
    model_params: dict[str, Any] = {
        "model_id": model_id,
        "custom_role_conversions": custom_role_conversions,
    }
    if model_id == "o1":
        model_params["reasoning_effort"] = "high"
        model_params["max_completion_tokens"] = 8192
    else:
        model_params["max_tokens"] = 4096
    model = LiteLLMModel(**model_params)

    document_inspection_tool = TextInspectorTool(model, 100000)

    agent = create_agent_team(model)

    augmented_question = """You have one question to answer. It is paramount that you provide a correct answer.
Give it all you can: I know for a fact that you have access to all the relevant tools to solve it and find the correct answer (the answer does exist).
Failure or 'I cannot answer' or 'None found' will not be tolerated, success will be rewarded.
Run verification steps if that's needed, you must make sure you find the correct answer! Here is the task:

""" + example["question"]

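    # If the example has an attached file, add a description of it to the prompt.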
    if example["file_name"]:
        if ".zip" in example["file_name"]:
            prompt_use_files = "\n\nTo solve the task above, you will have to use these attached files:\n"
            prompt_use_files += get_zip_description(
                example["file_name"], example["question"], visual_inspection_tool, document_inspection_tool
            )
        else:
            prompt_use_files = "\n\nTo solve the task above, you will have to use this attached file:\n"
            prompt_use_files += get_single_file_description(
                example["file_name"], example["question"], visual_inspection_tool, document_inspection_tool
            )
        augmented_question += prompt_use_files

    start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    try:
        final_result = agent.run(augmented_question)

        agent_memory = agent.write_memory_to_messages()

        # Reformulate the final answer from the agent's full message history.
        final_result = prepare_response(augmented_question, agent_memory, reformulation_model=model)

        output = str(final_result)
        for memory_step in agent.memory.steps:
            memory_step.model_input_messages = None
        intermediate_steps = agent_memory

        # Flag parsing errors and iteration-limit failures found in the run.
        # str(step) is needed because each step is a message dict, and `in`
        # on a dict would only check its keys.
        parsing_error = any("AgentParsingError" in str(step) for step in intermediate_steps)
        iteration_limit_exceeded = "Agent stopped due to iteration limit or time limit." in output
        raised_exception = False
    except Exception as e:
        print("Error on ", augmented_question, e)
        output = None
        intermediate_steps = []
        parsing_error = False
        iteration_limit_exceeded = False
        exception = e
        raised_exception = True
    end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
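    # Sum token usage across the manager and its managed search agent.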
    token_counts_manager = agent.monitor.get_total_token_counts()
    token_counts_web = list(agent.managed_agents.values())[0].monitor.get_total_token_counts()
    total_token_counts = {
        "input": token_counts_manager["input"] + token_counts_web["input"],
        "output": token_counts_manager["output"] + token_counts_web["output"],
    }
    annotated_example = {
        "agent_name": model.model_id,
        "question": example["question"],
        "augmented_question": augmented_question,
        "prediction": output,
        "intermediate_steps": intermediate_steps,
        "parsing_error": parsing_error,
        "iteration_limit_exceeded": iteration_limit_exceeded,
        "agent_error": str(exception) if raised_exception else None,
        "task": example["task"],
        "task_id": example["task_id"],
        "true_answer": example["true_answer"],
        "start_time": start_time,
        "end_time": end_time,
        "token_counts": total_token_counts,
    }
    append_answer(annotated_example, answers_file)


def get_examples_to_answer(answers_file: str, eval_ds: datasets.Dataset) -> list[dict]:
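    """Return the examples that do not yet have an answer recorded in answers_file."""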
    print(f"Loading answers from {answers_file}...")
    try:
        done_questions = pd.read_json(answers_file, lines=True)["question"].tolist()
        print(f"Found {len(done_questions)} previous results!")
    except Exception as e:
        print("Error when loading records: ", e)
        print("No usable records! ▶️ Starting new.")
        done_questions = []
    return [line for line in eval_ds.to_list() if line["question"] not in done_questions]


def main():
    args = parse_args()
    print(f"Starting run with arguments: {args}")

    eval_ds = load_gaia_dataset(args.use_raw_dataset, args.set_to_run)
    print("Loaded evaluation dataset:")
    print(pd.DataFrame(eval_ds)["task"].value_counts())

    answers_file = f"output/{args.set_to_run}/{args.run_name}.jsonl"
    tasks_to_run = get_examples_to_answer(answers_file, eval_ds)

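    # Answer the remaining questions concurrently; each worker appends its
    # result to the answers file as soon as it finishes.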
    with ThreadPoolExecutor(max_workers=args.concurrency) as exe:
        futures = [
            exe.submit(answer_single_question, example, args.model_id, answers_file, visualizer)
            for example in tasks_to_run
        ]
        for f in tqdm(as_completed(futures), total=len(tasks_to_run), desc="Processing tasks"):
            f.result()

    print("All tasks processed.")


if __name__ == "__main__":
    main()