Upload 7 files
Browse filesRefactored adding Evaluator and Runner classes
- agent.py +69 -0
- app.py +10 -13
- evaluator.py +120 -0
- runner.py +71 -0
- tools.py +199 -0
agent.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import contextlib
|
2 |
+
import io
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
logger = logging.getLogger(__name__)
|
6 |
+
from models import GoogleModelID, OpenRouterModelID
|
7 |
+
from settings import Settings
|
8 |
+
from smolagents import LiteLLMModel, CodeAgent
|
9 |
+
from smolagents import GoogleSearchTool, VisitWebpageTool, FinalAnswerTool
|
10 |
+
from smolagents.local_python_executor import BASE_PYTHON_TOOLS
|
11 |
+
from tools import GetTaskFileTool, VideoUnderstandingTool, AudioUnderstandingTool
|
12 |
+
from tools import ChessPiecePlacementTool, ChessGameFenTool, BestChessMoveTool, ConvertChessMoveTool
|
13 |
+
|
14 |
+
|
15 |
+
# Base tools may use these to process files
# NOTE(review): exposing `open`, `os`, `io`, `contextlib`, and especially
# `exec` to the agent's Python executor effectively disables the executor's
# sandboxing — generated code can read/write arbitrary files and run
# arbitrary code on the host. Kept as-is because downloaded task files must
# be readable by agent code, but confirm this is acceptable for the
# deployment environment.
BASE_PYTHON_TOOLS["open"] = open
BASE_PYTHON_TOOLS["os"] = os
BASE_PYTHON_TOOLS["io"] = io
BASE_PYTHON_TOOLS["contextlib"] = contextlib
BASE_PYTHON_TOOLS["exec"] = exec
|
21 |
+
|
22 |
+
|
23 |
+
class BasicAgent:
    """smolagents CodeAgent wired with web-search, file, media, and chess tools."""

    def __init__(self, settings: Settings):
        """Build the underlying CodeAgent.

        Args:
            settings: Application settings supplying the API keys used by the
                LLM backend and the custom tools.
        """
        self.agent = CodeAgent(
            add_base_tools=False,
            tools=[
                GoogleSearchTool("serper"),
                VisitWebpageTool(max_output_length=100000),
                FinalAnswerTool(),
                GetTaskFileTool(settings),
                VideoUnderstandingTool(settings, GoogleModelID.GEMINI_2_0_FLASH),
                AudioUnderstandingTool(settings, GoogleModelID.GEMINI_2_0_FLASH),
                ChessPiecePlacementTool(),
                ChessGameFenTool(settings, OpenRouterModelID.GPT_O4_MINI),
                BestChessMoveTool(settings),
                ConvertChessMoveTool(settings, OpenRouterModelID.QWEN_3_14B_FREE),
            ],
            # Modules the agent-generated code may import inside the executor.
            additional_authorized_imports=[
                "unicodedata",
                "stat",
                "datetime",
                "random",
                "pandas",
                "itertools",
                "math",
                "statistics",
                "queue",
                "time",
                "collections",
                "re",
                "os",
            ],
            max_steps=10,
            verbosity_level=1,
            model=LiteLLMModel(
                # Alternatives kept for quick switching during evaluation:
                # OpenRouterModelID.GPT_O4_MINI, OpenRouterModelID.GROK_3_BETA,
                # OpenRouterModelID.GROK_3_MINI_BETA
                model_id=OpenRouterModelID.GPT_4_1_MINI,
                api_key=settings.openrouter_api_key.get_secret_value(),
                temperature=0.0,  # deterministic answers for scoring
                timeout=180,
            ),
        )

    def __call__(self, question: str) -> str:
        """Run the agent on *question* and return its final answer string."""
        # Lazy %-style args keep formatting off the hot path (and out of
        # disabled log levels).
        logger.info("Agent received question (first 50 chars): %s...", question[:50])
        final_answer = self.agent.run(question)
        # Fixed misleading template leftover: this is the agent's computed
        # answer, not a "fixed" (hard-coded) one.
        logger.info("Agent returning final answer: %s", final_answer)
        return final_answer
|
app.py
CHANGED
@@ -3,8 +3,8 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExport
|
|
3 |
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
|
4 |
from opentelemetry.sdk.trace import TracerProvider
|
5 |
from opentelemetry import trace
|
6 |
-
|
7 |
-
|
8 |
from settings import Settings
|
9 |
import os
|
10 |
import pandas as pd
|
@@ -13,8 +13,8 @@ import logging
|
|
13 |
logging.basicConfig(level=logging.INFO, force=True)
|
14 |
logger = logging.getLogger(__name__)
|
15 |
settings = Settings()
|
16 |
-
|
17 |
-
|
18 |
|
19 |
|
20 |
# Create a TracerProvider for OpenTelemetry
|
@@ -45,22 +45,19 @@ EMPTY_RESULTS_TABLE = pd.DataFrame(columns=['task_id', 'question', 'answer'])
|
|
45 |
def run_one(profile: gr.OAuthProfile | None) -> pd.DataFrame:
|
46 |
if not user_logged_in(profile):
|
47 |
return LOGIN_MESSAGE, EMPTY_RESULTS_TABLE
|
48 |
-
|
49 |
-
|
50 |
-
return "You are logged in.", EMPTY_RESULTS_TABLE
|
51 |
|
52 |
def run_all(profile: gr.OAuthProfile | None) -> pd.DataFrame:
|
53 |
if not user_logged_in(profile):
|
54 |
return LOGIN_MESSAGE, EMPTY_RESULTS_TABLE
|
55 |
-
|
56 |
-
|
57 |
-
return "You are logged in.", EMPTY_RESULTS_TABLE
|
58 |
|
59 |
def submit(profile: gr.OAuthProfile | None):
|
60 |
if not user_logged_in(profile):
|
61 |
return LOGIN_MESSAGE
|
62 |
-
|
63 |
-
return "You are logged in."
|
64 |
|
65 |
|
66 |
# --- Build Gradio Interface using Blocks ---
|
@@ -77,7 +74,7 @@ with gr.Blocks() as demo:
|
|
77 |
---
|
78 |
**Disclaimers:**
|
79 |
Once clicking 'Get All Answers', it can take quite some time (this is the time for the agent to go through all 20 questions).
|
80 |
-
The agent will run question tasks in parallel making
|
81 |
The 'Submit All Answers' button will use the most recent agent answers cached in the space for your username.
|
82 |
"""
|
83 |
)
|
|
|
3 |
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
|
4 |
from opentelemetry.sdk.trace import TracerProvider
|
5 |
from opentelemetry import trace
|
6 |
+
from evaluator import Evaluator
|
7 |
+
from runner import Runner
|
8 |
from settings import Settings
|
9 |
import os
|
10 |
import pandas as pd
|
|
|
13 |
logging.basicConfig(level=logging.INFO, force=True)
|
14 |
logger = logging.getLogger(__name__)
|
15 |
settings = Settings()
|
16 |
+
evaluator = Evaluator(settings)
|
17 |
+
runner = Runner(settings)
|
18 |
|
19 |
|
20 |
# Create a TracerProvider for OpenTelemetry
|
|
|
45 |
def run_one(profile: gr.OAuthProfile | None) -> pd.DataFrame:
    """Answer one random question with the agent; requires a logged-in user."""
    if not user_logged_in(profile):
        return LOGIN_MESSAGE, EMPTY_RESULTS_TABLE
    selected = [evaluator.get_one_question()]
    return "Answer one random question...", runner.run_agent(selected)
|
|
|
50 |
|
51 |
def run_all(profile: gr.OAuthProfile | None) -> pd.DataFrame:
    """Answer every question from the evaluator; requires a logged-in user."""
    if not user_logged_in(profile):
        return LOGIN_MESSAGE, EMPTY_RESULTS_TABLE
    all_questions = evaluator.get_questions()
    return "Answer all 20 questions...", runner.run_agent(all_questions)
|
|
|
56 |
|
57 |
def submit(profile: gr.OAuthProfile | None):
    """Submit cached answers to the scoring endpoint; requires a logged-in user.

    Returns:
        The status message produced by the evaluator (score summary or error),
        so the UI can display the outcome.
    """
    if not user_logged_in(profile):
        return LOGIN_MESSAGE
    # Bug fix: the status string was previously discarded and the handler
    # returned None, leaving the UI with no feedback after submission.
    return evaluator.submit_answers()
|
|
|
61 |
|
62 |
|
63 |
# --- Build Gradio Interface using Blocks ---
|
|
|
74 |
---
|
75 |
**Disclaimers:**
|
76 |
Once clicking 'Get All Answers', it can take quite some time (this is the time for the agent to go through all 20 questions).
|
77 |
+
The agent(s) will run question tasks in parallel making the logs hard to follow. Langfuse instrumentation has been configured.
|
78 |
The 'Submit All Answers' button will use the most recent agent answers cached in the space for your username.
|
79 |
"""
|
80 |
)
|
evaluator.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from settings import Settings
|
2 |
+
from typing import List
|
3 |
+
from models import Question, QuestionAnswerPair, Results
|
4 |
+
import requests
|
5 |
+
import random
|
6 |
+
import json
|
7 |
+
import logging
|
8 |
+
logger = logging.getLogger(__name__)
|
9 |
+
|
10 |
+
|
11 |
+
class Evaluator:
    """Fetches questions from the scoring API and submits answers for scoring."""

    def __init__(self, settings: Settings):
        self.settings = settings

    def get_questions(self) -> list[Question]:
        """
        Get the questions from the HuggingFace endpoint.

        Falls back to the local ``questions.json`` cache when the endpoint is
        unreachable or returns malformed data (rate limits, outages, ...).

        Returns:
            list[Question]: A list of Question objects
        """
        url = str(self.settings.scoring_api_base_url) + "questions"
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            questions = [Question(**question) for question in response.json()]
            # Refresh the local cache used by the fallback path below.
            with open("questions.json", "w") as f:
                json.dump([question.model_dump()
                           for question in questions], f, indent=4)
        # Narrowed from a bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit). ValueError covers JSON decoding and
        # pydantic validation failures.
        except (requests.RequestException, ValueError) as e:
            # Read local file instead, dealing with rate limits, etc.
            logger.warning("Falling back to cached questions.json: %s", e)
            with open("questions.json", "r") as f:
                questions = [Question(**question) for question in json.load(f)]
        return questions

    def get_one_question(self, task_id=None) -> Question:
        """
        Get a random, or requested question from the HuggingFace endpoint.

        Args:
            task_id: When given, return that specific question; if it is not
                found, fall through to a random one (original behavior kept).

        Returns:
            Question: A Question object
        """
        if task_id:
            # Removed a redundant nested `if task_id:` re-check.
            for question in self.get_questions():
                if question.task_id == task_id:
                    return question
        try:
            url = str(self.settings.scoring_api_base_url) + "random-question"
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            return Question(**response.json())
        except (requests.RequestException, ValueError) as e:
            # Read local file instead, dealing with rate limits, etc.
            logger.warning("Falling back to a random cached question: %s", e)
            return random.choice(self.get_questions())

    def _read_answer_file(self) -> List[str]:
        """Load cached answers.json and format each pair for the submit payload."""
        with open("answers.json", "r") as f:
            pairs = [QuestionAnswerPair(**pair) for pair in json.load(f)]
        return [pair.get_answer() for pair in pairs]

    def submit_answers(self) -> str:
        """Submits saved answers to the scoring endpoint and returns the result."""
        answers_payload = self._read_answer_file()
        agent_code = f"https://huggingface.co/spaces/{self.settings.space_id}/tree/main"
        submission_data = {
            "username": self.settings.username,
            "agent_code": agent_code,
            "answers": answers_payload}
        submit_url = str(self.settings.scoring_api_base_url) + "submit"
        logger.info(f"Submitting {len(answers_payload)} answers to: {submit_url}")
        try:
            response = requests.post(
                submit_url, json=submission_data, timeout=60)
            response.raise_for_status()
            results = Results.model_validate(response.json())
            logger.info(
                f"Submission successful.\n"
                f"User: {results.username}.\n"
                f"Overall Score: {results.score}%.\n"
                f"Correct Count: {results.correct_count}.\n"
                f"Total Attempted: {results.total_attempted}.\n"
                f"Message: {results.message}.\n"
                f"Timestamp: {results.timestamp}.\n"
            )
            status_message = (
                f"Submission Successful!\n"
                f"User: {results.username}\n"
                f"Overall Score: {results.score}% "
                f"({results.correct_count}/{results.total_attempted} correct)\n"
                f"Message: {results.message}"
            )
            return status_message
        except requests.exceptions.HTTPError as e:
            error_detail = f"Server responded with status {e.response.status_code}."
            try:
                error_json = e.response.json()
                error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
            except requests.exceptions.JSONDecodeError:
                error_detail += f" Response: {e.response.text[:500]}"
            status_message = f"Submission Failed: {error_detail}"
            # Failures are now logged at ERROR level (previously INFO).
            logger.error(status_message)
            return status_message
        except requests.exceptions.Timeout:
            status_message = "Submission Failed: The request timed out."
            logger.error(status_message)
            return status_message
        except requests.exceptions.RequestException as e:
            status_message = f"Submission Failed: Network error - {e}"
            logger.error(status_message)
            return status_message
        except Exception as e:
            status_message = f"An unexpected error occurred during submission: {e}"
            logger.error(status_message)
            return status_message
|
runner.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from settings import Settings
|
2 |
+
from models import QuestionAnswerPair
|
3 |
+
from agent import BasicAgent
|
4 |
+
import pandas as pd
|
5 |
+
import logging
|
6 |
+
import json
|
7 |
+
import asyncio
|
8 |
+
import nest_asyncio
|
9 |
+
nest_asyncio.apply()
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
class Runner:
    """Runs one BasicAgent per question concurrently and persists the answers."""

    def __init__(self, settings: Settings):
        self.settings = settings

    def _save_pairs(self, pairs: list[QuestionAnswerPair]):
        """Persist answered pairs to answers.json (read later at submission time)."""
        answers = [pair.model_dump() for pair in pairs if pair is not None]
        with open("answers.json", "w") as f:
            json.dump(answers, f, indent=4)

    def _enrich_question_text(self, item):
        """Wrap the raw question with answer-format instructions and task metadata."""
        task_id = item.task_id
        file_name = item.file_name
        question_text = (
            f"{item.question} "
            "Think hard to answer. Parse all statements in the question to make a plan. "
            "Your final answer should be a number or as few words as possible. "
            "If needed, use a comma separated list of numbers and/or strings. Critically "
            f"review your answer before making it the final answer. task_id: {task_id}."
        )
        if file_name:
            question_text = f"{question_text} file_name: {file_name} (use tools to fetch the file)"
        return question_text

    async def _run_agent_async(self, item):
        """Runs one agent on one question, off the event loop's thread.

        Agent failures are captured into the answer text rather than raised,
        so one bad task cannot abort the whole gather.
        """
        task_id = item.task_id
        question_text = self._enrich_question_text(item)
        try:
            # Fresh agent per task; the blocking run is moved to a worker thread.
            answer = await asyncio.to_thread(BasicAgent(self.settings), question_text)
        except Exception as e:
            logger.error("Error running agent on task %s: %s", task_id, e)
            answer = f"AGENT ERROR: {e}"
        return QuestionAnswerPair(task_id=task_id,
                                  question=item.question, answer=str(answer))

    def _assign_questions(self, questions):
        """Return an awaitable that gathers one agent run per question."""
        tasks = [self._run_agent_async(item) for item in questions]
        return asyncio.gather(*tasks)

    def run_agent(self, questions) -> pd.DataFrame:
        """Run the agent(s) async, save answers and return a dataframe."""
        # nest_asyncio.apply() at import time allows run_until_complete on the
        # already-running (Gradio) event loop without "loop is running" errors.
        loop = asyncio.get_running_loop()
        # Removed the misleadingly named `run_tasks_in_thread` closure: no
        # thread was involved — it just wrapped this single call.
        pairs = loop.run_until_complete(self._assign_questions(questions))

        # Save json to disk and return a dataframe.
        self._save_pairs(pairs)
        results_log = [pair.model_dump() for pair in pairs if pair is not None]
        if not results_log:
            logger.warning("Agent did not produce any answers to submit.")
        return pd.DataFrame(results_log)
|
tools.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
logger = logging.getLogger(__name__)
|
5 |
+
import requests
|
6 |
+
import shutil
|
7 |
+
from typing import Any
|
8 |
+
import urllib.parse
|
9 |
+
from board_to_fen.predict import get_fen_from_image_path
|
10 |
+
from google import genai
|
11 |
+
from google.genai import types
|
12 |
+
from litellm import completion
|
13 |
+
from smolagents import Tool
|
14 |
+
from settings import Settings
|
15 |
+
|
16 |
+
|
17 |
+
class BaseCustomTool(Tool):
    """Base for smolagents tools that need access to application Settings.

    Stores the settings object so subclasses can read API keys and endpoint
    URLs from ``self.settings``.
    """

    def __init__(self, settings):
        super().__init__()
        self.settings = settings
|
21 |
+
|
22 |
+
class GetTaskFileTool(BaseCustomTool):
    name = "get_task_file_tool"
    description = """Download the file_name associated with a given task_id. Get absolute file path"""
    inputs = {
        "task_id": {"type": "string", "description": "Task ID"},
        "file_name": {"type": "string", "description": "File name"},
    }
    output_type = "string"

    def __init__(self, settings):
        super().__init__(settings)
        self.directory_name = "downloads"
        self.create_dir()

    def forward(self, task_id: str, file_name: str) -> str:
        """Download the task's file and return its absolute path.

        Falls back to a bundled copy under ``files/`` when the download fails
        (rate limits, offline, ...).
        """
        destination = f"{self.directory_name}/{file_name}"
        try:
            response = requests.get(
                f"{self.settings.evaluation_api_base_url}/files/{task_id}",
                timeout=15)
            response.raise_for_status()
            with open(destination, 'wb') as file:
                file.write(response.content)
        except Exception as e:
            # Fetch the local file instead, dealing with rate limits, etc.
            # The fallback is no longer silent — the swallowed exception is logged.
            logger.warning("Download failed for task %s (%s); using local copy.",
                           task_id, e)
            shutil.copy2(f"files/{file_name}", destination)
        # Bug fix: the fallback previously returned a relative path while the
        # success path (and the tool description) promise an absolute one.
        return os.path.abspath(destination)

    def create_dir(self):
        # Create the download directory if it doesn't exist
        if not os.path.exists(self.directory_name):
            os.makedirs(self.directory_name)
            logger.info(f"Directory '{self.directory_name}' created successfully.")
        else:
            logger.debug(f"Directory '{self.directory_name}' already exists.")
|
55 |
+
|
56 |
+
class VideoUnderstandingTool(BaseCustomTool):
    name = "VideoUnderstanding"
    description = "Prompt a YouTube video with questions to understand its content."
    inputs = {
        "youtube_url": {"type": "string", "description": "The URL of the YouTube video"},
        "prompt": {"type": "string", "description": "A question or request regarding the video"},
    }
    output_type = "string"

    def __init__(self, settings, model):
        super().__init__(settings)
        self.model = model  # Gemini model id used for video understanding

    def forward(self, youtube_url: str, prompt: str) -> str:
        """Ask the Gemini model *prompt* about the video at *youtube_url*."""
        client = genai.Client(api_key=self.settings.gemini_api_key.get_secret_value())
        try:
            video_description = client.models.generate_content(
                model=self.model,
                contents=types.Content(
                    parts=[
                        # The video is passed by URI; Gemini fetches it itself.
                        types.Part(
                            file_data=types.FileData(file_uri=youtube_url)
                        ),
                        types.Part(text=prompt)
                    ]
                )
            )
            return video_description.text
        except Exception as e:
            logger.error(f"Error understanding video: {e}")
            # Bug fix: previously returned the bool `False`, violating the
            # declared string output_type; return the error text instead so
            # the agent sees why the tool failed.
            return f"Error understanding video: {e}"
|
87 |
+
|
88 |
+
class AudioUnderstandingTool(BaseCustomTool):
    name = "AudioUnderstanding"
    description = "Prompt a local audio file with questions to understand its content."
    inputs = {
        "file_path": {"type": "string", "description": "The local file of the audio"},
        "prompt": {"type": "string", "description": "A question or request regarding the audio"},
    }
    output_type = "string"

    def __init__(self, settings, model):
        super().__init__(settings)
        self.model = model  # Gemini model id used for audio understanding

    def forward(self, file_path: str, prompt: str) -> str:
        """Upload the audio at *file_path* and ask the Gemini model *prompt* about it."""
        client = genai.Client(api_key=self.settings.gemini_api_key.get_secret_value())
        try:
            # Audio must be uploaded first; the returned handle is then
            # referenced from the generate_content call.
            mp3_file = client.files.upload(file=f"{file_path}")
            audio_description = client.models.generate_content(
                model=self.model,
                contents=[prompt, mp3_file]
            )
            return audio_description.text
        except Exception as e:
            logger.error(f"Error understanding audio: {e}")
            # Bug fix: previously returned the bool `False`, violating the
            # declared string output_type; return the error text instead so
            # the agent sees why the tool failed.
            return f"Error understanding audio: {e}"
|
113 |
+
|
114 |
+
class ConvertChessMoveTool(BaseCustomTool):
    """LLM-backed converter from coordinate to algebraic chess notation."""
    name = "ConvertChessMove"
    description = "Convert a chess move from coordinate notation to algebraic notation."
    inputs = {
        "piece_placement": {"type": "string", "description": "The chess piece placement in plain text"},
        "move": {"type": "string", "description": "The move in coordinate notation (e.g., e2e4)"},
    }
    output_type = "string"

    def __init__(self, settings, model):
        super().__init__(settings)
        # OpenRouter model id passed straight through to litellm.completion.
        self.model = model

    def forward(self, piece_placement: str, move: str) -> str:
        # The prompt forbids commentary, so the raw completion text is
        # expected to be the algebraic move only — no post-processing.
        # NOTE(review): the model output is returned unvalidated; presumably
        # the model complies, but there is no guard if it adds extra text.
        move_message = f"""Convert this chess move from coordinate notation to algebraic
        notation: {move}. Use the following {piece_placement}. Do not provide any additional
        thinking or commentary in the response, the algebraic notation only."""
        messages = [{ "content": move_message,"role": "user"}]
        response = completion(
            model=self.model,
            temperature=0.0,
            messages=messages,
            api_key=self.settings.openrouter_api_key.get_secret_value()
        )
        return response.choices[0].message.content
|
139 |
+
|
140 |
+
class BestChessMoveTool(BaseCustomTool):
    name = "BestChessMove"
    description = "Get best chess move in coordinate notation based on a FEN representation."
    inputs = {
        "fen": {"type": "string", "description": "The FEN (Forsyth-Edwards Notation) \
            representation of the chess position. Example \
            rn1q1rk1/pp2b1pp/2p2n2/3p1pB1/3P4/1QP2N2/PP1N1PPP/R4RK1 b - - 1 11"},
    }
    output_type = "string"

    def forward(self, fen: str) -> str:
        """Query the chess evaluation service and return the best move (e.g. 'e2e4')."""
        try:
            url = f"{self.settings.chess_eval_url}?fen={urllib.parse.quote(fen)}&depth=15"
            response = requests.get(url, timeout=15)
            # Parse the body once via response.json() instead of calling
            # json.loads(response.text) twice.
            payload = response.json()
            if response.status_code == 200 and payload['success'] is True:
                # 'bestmove' looks like "bestmove e2e4 ..."; keep the move token.
                return payload['bestmove'].split()[1]
            raise ValueError(f"Error getting chess evaluation: {response.status_code}")
        except Exception as e:
            logger.error(f"Error getting chess evaluation: {e}")
            # Bug fix: previously fell off the end returning None despite the
            # declared string output_type; return the error text instead.
            return f"Error getting chess evaluation: {e}"
|
160 |
+
|
161 |
+
class ChessGameFenTool(BaseCustomTool):
    """LLM-backed conversion of a plain-text piece placement into a FEN string."""
    name = "ChessGameFen"
    description = "Get a FEN representation given chess piece placement and a move."
    inputs = {
        "piece_placement": {"type": "string", "description": "The chess piece placement in plain text"},
        "player_turn": {"type": "string",
                        "description": "The player with the next turn in the match, black or white"},
    }
    output_type = "string"

    def __init__(self, settings, model):
        super().__init__(settings)
        # OpenRouter model id passed straight through to litellm.completion.
        self.model = model

    def forward(self, piece_placement: str, player_turn: str) -> str:
        """Use the tool."""
        # The prompt forbids commentary, so the raw completion text is
        # expected to be the FEN only — it is returned unvalidated.
        # NOTE(review): presumably the model complies; there is no guard if
        # it adds extra text around the FEN.
        fen_message = f"""Assuming {player_turn} has the next turn, Use the following placement
        {piece_placement} and provide the board state as FEN. Do not provide any
        additional thinking or commentary in the response, the FEN only."""
        messages = [{ "content": fen_message,"role": "user"}]
        response = completion(
            model=self.model,
            temperature=0.0,
            messages=messages,
            api_key=self.settings.openrouter_api_key.get_secret_value()
        )
        return response.choices[0].message.content
|
188 |
+
|
189 |
+
class ChessPiecePlacementTool(Tool):
    """Extract piece placement from a chessboard image via board_to_fen."""
    name = "ChessPiecePlacement"
    description = "Get chess piece placement information from an image of a board."
    inputs = {
        "image_path": {"type": "string", "description": "The local file of the chess board image"},
    }
    output_type = "string"

    def forward(self, image_path: str) -> str:
        # Delegates entirely to the board_to_fen library; presumably returns
        # the FEN piece-placement field inferred from the image — TODO confirm
        # against board_to_fen's documentation.
        return get_fen_from_image_path(image_path)
|
199 |
+
|