import logging
import os
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional

import click
from click_default_group import DefaultGroup
from dotenv import load_dotenv

from agbenchmark.config import AgentBenchmarkConfig
from agbenchmark.utils.logging import configure_logging

load_dotenv()

# try:
#     if os.getenv("HELICONE_API_KEY"):
#         import helicone  # noqa
#         helicone_enabled = True
#     else:
#         helicone_enabled = False
# except ImportError:
#     helicone_enabled = False


class InvalidInvocationError(ValueError):
    pass


logger = logging.getLogger(__name__)

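# UTC timestamp identifying this benchmark run, captured once at import time
# and formatted as an ISO 8601 string with an explicit +00:00 offset.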
BENCHMARK_START_TIME_DT = datetime.now(timezone.utc)
BENCHMARK_START_TIME = BENCHMARK_START_TIME_DT.strftime("%Y-%m-%dT%H:%M:%S+00:00")

# if helicone_enabled:
#     from helicone.lock import HeliconeLockManager
#     HeliconeLockManager.write_custom_property(
#         "benchmark_start_time", BENCHMARK_START_TIME
#     )


# NOTE: click group/option declarations inferred from the function signature.
@click.group(cls=DefaultGroup, default_if_no_args=True)
@click.option("--debug", is_flag=True, help="Enable debug output")
def cli(
    debug: bool,
) -> Any:
    configure_logging(logging.DEBUG if debug else logging.INFO)


@cli.command(hidden=True)
def start():
    raise DeprecationWarning(
        "`agbenchmark start` is deprecated. Use `agbenchmark run` instead."
    )


@cli.command(default=True)
# NOTE: the option flags below are inferred from the parameters of `run()`;
# help texts are approximate.
@click.option("-N", "--attempts", default=1, help="Number of times to run each challenge.")
@click.option("-c", "--category", multiple=True, help="(+) Select a category to run.")
@click.option("-s", "--skip-category", multiple=True, help="(+) Exclude a category from running.")
@click.option("--test", multiple=True, help="(+) Select a test to run.")
@click.option("--maintain", is_flag=True, help="Run only regression tests.")
@click.option("--improve", is_flag=True, help="Run only non-regression tests.")
@click.option("--explore", is_flag=True, help="Run only challenges that have never been beaten.")
@click.option("--no-dep", is_flag=True, help="Run all selected challenges, regardless of dependency success.")
@click.option("--cutoff", type=int, help="Override the challenge time limit (seconds).")
@click.option("--nc", is_flag=True, help="Disable the challenge time limit.")
@click.option("--mock", is_flag=True, help="Run with mock responses.")
@click.option("--keep-answers", is_flag=True, help="Keep the answers produced during the run.")
@click.option("--backend", is_flag=True, help="Write output to a file instead of the terminal.")
# @click.argument(
#     "agent_path",
#     type=click.Path(exists=True, file_okay=False, path_type=Path),
#     required=False,
# )
def run(
    maintain: bool,
    improve: bool,
    explore: bool,
    mock: bool,
    no_dep: bool,
    nc: bool,
    keep_answers: bool,
    test: tuple[str],
    category: tuple[str],
    skip_category: tuple[str],
    attempts: int,
    cutoff: Optional[int] = None,
    backend: Optional[bool] = False,
    # agent_path: Optional[Path] = None,
) -> None:
    """
    Run the benchmark on the agent in the current directory.

    Options marked with (+) can be specified multiple times, to select multiple items.
    """
    from agbenchmark.main import run_benchmark, validate_args

    agbenchmark_config = AgentBenchmarkConfig.load()
    logger.debug(f"agbenchmark_config: {agbenchmark_config.agbenchmark_config_dir}")

    try:
        validate_args(
            maintain=maintain,
            improve=improve,
            explore=explore,
            tests=test,
            categories=category,
            skip_categories=skip_category,
            no_cutoff=nc,
            cutoff=cutoff,
        )
    except InvalidInvocationError as e:
        logger.error("Error: " + "\n".join(e.args))
        sys.exit(1)

    original_stdout = sys.stdout  # Save the original standard output
    exit_code = None
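    # In backend mode, presumably when invoked by the serve/API frontend, everything
    # the benchmark writes to stdout is redirected into backend/backend_stdout.txt
    # (the backend/ directory is assumed to already exist) and restored afterwards.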
    if backend:
        with open("backend/backend_stdout.txt", "w") as f:
            sys.stdout = f
            exit_code = run_benchmark(
                config=agbenchmark_config,
                maintain=maintain,
                improve=improve,
                explore=explore,
                mock=mock,
                no_dep=no_dep,
                no_cutoff=nc,
                keep_answers=keep_answers,
                tests=test,
                categories=category,
                skip_categories=skip_category,
                attempts_per_challenge=attempts,
                cutoff=cutoff,
            )

        sys.stdout = original_stdout

    else:
        exit_code = run_benchmark(
            config=agbenchmark_config,
            maintain=maintain,
            improve=improve,
            explore=explore,
            mock=mock,
            no_dep=no_dep,
            no_cutoff=nc,
            keep_answers=keep_answers,
            tests=test,
            categories=category,
            skip_categories=skip_category,
            attempts_per_challenge=attempts,
            cutoff=cutoff,
        )

    sys.exit(exit_code)


@cli.command()
@click.option("--port", type=int, help="Port to serve on (overrides $PORT; defaults to 8080).")
def serve(port: Optional[int] = None):
    """Serve the benchmark frontend and API on port 8080."""
    import uvicorn

    from agbenchmark.app import setup_fastapi_app

    config = AgentBenchmarkConfig.load()
    app = setup_fastapi_app(config)

    # Run the FastAPI application using uvicorn
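    # Port precedence: --port option, then the PORT environment variable, then 8080.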
    port = port or int(os.getenv("PORT", 8080))
    uvicorn.run(app, host="0.0.0.0", port=port)


@cli.command()
def config():
    """Displays info regarding the present AGBenchmark config."""
    from .utils.utils import pretty_print_model

    try:
        config = AgentBenchmarkConfig.load()
    except FileNotFoundError as e:
        click.echo(e, err=True)
        return 1

    pretty_print_model(config, include_header=False)


@cli.group()
def challenge():
    logging.getLogger().setLevel(logging.WARNING)


@challenge.command("list")
# NOTE: the option declarations below are inferred from the parameters of
# `list_challenges()`; exact flag spellings may differ.
@click.option("--all", "include_unavailable", is_flag=True, help="Include unavailable challenges.")
@click.option("--names", "only_names", is_flag=True, help="List only the challenge names.")
@click.option("--json", "output_json", is_flag=True, help="Output the list as JSON.")
def list_challenges(include_unavailable: bool, only_names: bool, output_json: bool):
    """Lists [available|all] challenges."""
    import json

    from tabulate import tabulate

    from .challenges.builtin import load_builtin_challenges
    from .challenges.webarena import load_webarena_challenges
    from .utils.data_types import Category, DifficultyLevel
    from .utils.utils import sorted_by_enum_index
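
    # Map each difficulty/category enum member to a terminal color, pairing them
    # up in enum declaration order (zip stops at the shorter sequence).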
    DIFFICULTY_COLORS = {
        difficulty: color
        for difficulty, color in zip(
            DifficultyLevel,
            ["black", "blue", "cyan", "green", "yellow", "red", "magenta", "white"],
        )
    }
    CATEGORY_COLORS = {
        category: f"bright_{color}"
        for category, color in zip(
            Category,
            ["blue", "cyan", "green", "yellow", "magenta", "red", "white", "black"],
        )
    }

    # Load challenges
    challenges = filter(
        lambda c: c.info.available or include_unavailable,
        [
            *load_builtin_challenges(),
            *load_webarena_challenges(skip_unavailable=False),
        ],
    )
    challenges = sorted_by_enum_index(
        challenges, DifficultyLevel, key=lambda c: c.info.difficulty
    )

    if only_names:
        if output_json:
            click.echo(json.dumps([c.info.name for c in challenges]))
            return

        for c in challenges:
            click.echo(
                click.style(c.info.name, fg=None if c.info.available else "black")
            )
        return

    if output_json:
        click.echo(
            json.dumps([json.loads(c.info.model_dump_json()) for c in challenges])
        )
        return

    headers = tuple(
        click.style(h, bold=True) for h in ("Name", "Difficulty", "Categories")
    )
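    # Build one table row per challenge; rows for unavailable challenges are
    # dimmed by styling every cell black.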
    table = [
        tuple(
            v if challenge.info.available else click.style(v, fg="black")
            for v in (
                challenge.info.name,
                (
                    click.style(
                        challenge.info.difficulty.value,
                        fg=DIFFICULTY_COLORS[challenge.info.difficulty],
                    )
                    if challenge.info.difficulty
                    else click.style("-", fg="black")
                ),
                " ".join(
                    click.style(cat.value, fg=CATEGORY_COLORS[cat])
                    for cat in sorted_by_enum_index(challenge.info.category, Category)
                ),
            )
        )
        for challenge in challenges
    ]
    click.echo(tabulate(table, headers=headers))


@challenge.command()
# NOTE: argument/option declarations inferred from the signature of `info()`.
@click.argument("name")
@click.option("--json", is_flag=True, help="Output challenge info as JSON.")
def info(name: str, json: bool):
    from itertools import chain

    from .challenges.builtin import load_builtin_challenges
    from .challenges.webarena import load_webarena_challenges
    from .utils.utils import pretty_print_model
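
    # Search all known challenges for an exact name match; the for/else below
    # only reaches the else branch when no challenge matched.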
    for challenge in chain(
        load_builtin_challenges(),
        load_webarena_challenges(skip_unavailable=False),
    ):
        if challenge.info.name != name:
            continue

        if json:
            click.echo(challenge.info.model_dump_json())
            break

        pretty_print_model(challenge.info)
        break
    else:
        click.echo(click.style(f"Unknown challenge '{name}'", fg="red"), err=True)


@cli.command()
def version():
    """Print version info for the AGBenchmark application."""
    import toml

    package_root = Path(__file__).resolve().parent.parent
    pyproject = toml.load(package_root / "pyproject.toml")
    version = pyproject["tool"]["poetry"]["version"]
    click.echo(f"AGBenchmark version {version}")


if __name__ == "__main__":
    cli()
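

# Example invocations (assuming the option flags reconstructed above match the
# installed CLI):
#   python -m agbenchmark run --category coding -N 2
#   python -m agbenchmark challenge list --names
#   python -m agbenchmark serve --port 8080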