civerson916 commited on
Commit
476b409
·
verified ·
1 Parent(s): 64e1e58

Upload 7 files

Browse files

Refactored, adding Evaluator and Runner classes

Files changed (5) hide show
  1. agent.py +69 -0
  2. app.py +10 -13
  3. evaluator.py +120 -0
  4. runner.py +71 -0
  5. tools.py +199 -0
agent.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import contextlib
import io
import logging
import os
logger = logging.getLogger(__name__)
from models import GoogleModelID, OpenRouterModelID
from settings import Settings
from smolagents import LiteLLMModel, CodeAgent
from smolagents import GoogleSearchTool, VisitWebpageTool, FinalAnswerTool
from smolagents.local_python_executor import BASE_PYTHON_TOOLS
from tools import GetTaskFileTool, VideoUnderstandingTool, AudioUnderstandingTool
from tools import ChessPiecePlacementTool, ChessGameFenTool, BestChessMoveTool, ConvertChessMoveTool


# Base tools may use these to process files
# NOTE(review): injecting open/os/io/contextlib/exec into the smolagents
# local-executor builtins deliberately weakens its sandbox so that
# agent-generated code can read the downloaded task files. This grants
# generated code arbitrary file-system and exec access — confirm this
# trade-off is intended before deploying outside a throwaway Space.
BASE_PYTHON_TOOLS["open"] = open
BASE_PYTHON_TOOLS["os"] = os
BASE_PYTHON_TOOLS["io"] = io
BASE_PYTHON_TOOLS["contextlib"] = contextlib
BASE_PYTHON_TOOLS["exec"] = exec
21
+
22
+
23
class BasicAgent:
    """A smolagents ``CodeAgent`` wired with web-search, file, media and chess tools.

    The agent answers GAIA-style benchmark questions: it can search the web
    (Serper), visit pages, fetch task files, inspect video/audio through
    Gemini, and analyse chess positions through the dedicated chess tools.
    """

    def __init__(self, settings: Settings):
        """Build the underlying CodeAgent.

        Args:
            settings: Application settings supplying API keys and endpoints.
        """
        self.agent = CodeAgent(
            add_base_tools=False,
            tools=[GoogleSearchTool("serper"),
                   VisitWebpageTool(max_output_length=100000),
                   FinalAnswerTool(),
                   GetTaskFileTool(settings),
                   VideoUnderstandingTool(settings, GoogleModelID.GEMINI_2_0_FLASH),
                   AudioUnderstandingTool(settings, GoogleModelID.GEMINI_2_0_FLASH),
                   ChessPiecePlacementTool(),
                   ChessGameFenTool(settings, OpenRouterModelID.GPT_O4_MINI),
                   BestChessMoveTool(settings),
                   ConvertChessMoveTool(settings, OpenRouterModelID.QWEN_3_14B_FREE)
                   ],
            # Modules the generated code is allowed to import inside the executor.
            additional_authorized_imports=[
                "unicodedata",
                "stat",
                "datetime",
                "random",
                "pandas",
                "itertools",
                "math",
                "statistics",
                "queue",
                "time",
                "collections",
                "re",
                "os"
            ],
            max_steps=10,
            verbosity_level=1,
            model=LiteLLMModel(
                # model_id=OpenRouterModelID.GPT_O4_MINI,
                model_id=OpenRouterModelID.GPT_4_1_MINI,
                # model_id=OpenRouterModelID.GROK_3_BETA,
                # model_id=OpenRouterModelID.GROK_3_MINI_BETA,
                api_key=settings.openrouter_api_key.get_secret_value(),
                temperature=0.0, timeout=180
            )
        )

    def __call__(self, question: str) -> str:
        """Run the agent on *question* and return its final answer.

        Args:
            question: The full question text (may embed task metadata).

        Returns:
            The agent's final answer, coerced to ``str`` so the declared
            return type always holds (``agent.run`` may yield non-strings).
        """
        logger.info(f"Agent received question (first 50 chars): {question[:50]}...")
        final_answer = self.agent.run(question)
        # Fixed misleading log text: the answer is computed, not "fixed".
        logger.info(f"Agent returning final answer: {final_answer}")
        return str(final_answer)
app.py CHANGED
@@ -3,8 +3,8 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExport
3
  from openinference.instrumentation.smolagents import SmolagentsInstrumentor
4
  from opentelemetry.sdk.trace import TracerProvider
5
  from opentelemetry import trace
6
- # from evaluator import Evaluator
7
- # from runner import Runner
8
  from settings import Settings
9
  import os
10
  import pandas as pd
@@ -13,8 +13,8 @@ import logging
13
  logging.basicConfig(level=logging.INFO, force=True)
14
  logger = logging.getLogger(__name__)
15
  settings = Settings()
16
- # evaluator = Evaluator(settings)
17
- # runner = Runner(settings)
18
 
19
 
20
  # Create a TracerProvider for OpenTelemetry
@@ -45,22 +45,19 @@ EMPTY_RESULTS_TABLE = pd.DataFrame(columns=['task_id', 'question', 'answer'])
45
  def run_one(profile: gr.OAuthProfile | None) -> pd.DataFrame:
46
  if not user_logged_in(profile):
47
  return LOGIN_MESSAGE, EMPTY_RESULTS_TABLE
48
- # questions = [evaluator.get_one_question()]
49
- # return "Answer one random question...", runner.run_agent(questions)
50
- return "You are logged in.", EMPTY_RESULTS_TABLE
51
 
52
  def run_all(profile: gr.OAuthProfile | None) -> pd.DataFrame:
53
  if not user_logged_in(profile):
54
  return LOGIN_MESSAGE, EMPTY_RESULTS_TABLE
55
- # questions = evaluator.get_questions()
56
- # return "Answer all 20 questions...", runner.run_agent(questions)
57
- return "You are logged in.", EMPTY_RESULTS_TABLE
58
 
59
  def submit(profile: gr.OAuthProfile | None):
60
  if not user_logged_in(profile):
61
  return LOGIN_MESSAGE
62
- # evaluator.submit_answers()
63
- return "You are logged in."
64
 
65
 
66
  # --- Build Gradio Interface using Blocks ---
@@ -77,7 +74,7 @@ with gr.Blocks() as demo:
77
  ---
78
  **Disclaimers:**
79
  Once clicking 'Get All Answers', it can take quite some time (this is the time for the agent to go through all 20 questions).
80
- The agent will run question tasks in parallel making observability tools a must. Langfuse instrumentation has been configured.
81
  The 'Submit All Answers' button will use the most recent agent answers cached in the space for your username.
82
  """
83
  )
 
3
  from openinference.instrumentation.smolagents import SmolagentsInstrumentor
4
  from opentelemetry.sdk.trace import TracerProvider
5
  from opentelemetry import trace
6
+ from evaluator import Evaluator
7
+ from runner import Runner
8
  from settings import Settings
9
  import os
10
  import pandas as pd
 
13
  logging.basicConfig(level=logging.INFO, force=True)
14
  logger = logging.getLogger(__name__)
15
  settings = Settings()
16
+ evaluator = Evaluator(settings)
17
+ runner = Runner(settings)
18
 
19
 
20
  # Create a TracerProvider for OpenTelemetry
 
45
def run_one(profile: gr.OAuthProfile | None) -> tuple[str, pd.DataFrame]:
    """Answer one random benchmark question with the agent.

    Returns a (status message, results dataframe) pair; when the user is
    not logged in, the dataframe is empty and the message asks them to log in.
    """
    if not user_logged_in(profile):
        return LOGIN_MESSAGE, EMPTY_RESULTS_TABLE
    questions = [evaluator.get_one_question()]
    return "Answer one random question...", runner.run_agent(questions)
 
50
 
51
  def run_all(profile: gr.OAuthProfile | None) -> pd.DataFrame:
52
  if not user_logged_in(profile):
53
  return LOGIN_MESSAGE, EMPTY_RESULTS_TABLE
54
+ questions = evaluator.get_questions()
55
+ return "Answer all 20 questions...", runner.run_agent(questions)
 
56
 
57
def submit(profile: gr.OAuthProfile | None):
    """Submit the cached answers for the logged-in user.

    Returns:
        The scoring endpoint's status message, or a login prompt when the
        user is not authenticated.
    """
    if not user_logged_in(profile):
        return LOGIN_MESSAGE
    # Bug fix: the status string from submit_answers() was previously
    # discarded, so this handler returned None and the UI showed nothing.
    return evaluator.submit_answers()
 
61
 
62
 
63
  # --- Build Gradio Interface using Blocks ---
 
74
  ---
75
  **Disclaimers:**
76
  Once clicking 'Get All Answers', it can take quite some time (this is the time for the agent to go through all 20 questions).
77
+ The agent(s) will run question tasks in parallel making the logs hard to follow. Langfuse instrumentation has been configured.
78
  The 'Submit All Answers' button will use the most recent agent answers cached in the space for your username.
79
  """
80
  )
evaluator.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from settings import Settings
2
+ from typing import List
3
+ from models import Question, QuestionAnswerPair, Results
4
+ import requests
5
+ import random
6
+ import json
7
+ import logging
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
class Evaluator:
    """Fetches benchmark questions and submits answers to the scoring API.

    Falls back to local JSON caches (``questions.json`` / ``answers.json``)
    when the remote endpoint is unreachable or rate-limited.
    """

    def __init__(self, settings: Settings):
        """Store the application settings (endpoint URLs, username, space id)."""
        self.settings = settings

    def get_questions(self) -> list[Question]:
        """
        Get the questions from the HuggingFace endpoint.

        Returns:
            list[Question]: A list of Question objects
        """
        url = str(self.settings.scoring_api_base_url) + "questions"
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            questions = [Question(**question) for question in response.json()]
            # Refresh the local cache so later calls can work offline.
            with open("questions.json", "w") as f:
                json.dump([question.model_dump()
                           for question in questions], f, indent=4)
        except Exception as e:  # was a bare `except:` — don't swallow SystemExit/KeyboardInterrupt
            # Read local file instead, dealing with rate limits, etc.
            logger.warning(f"Falling back to cached questions.json: {e}")
            with open("questions.json", "r") as f:
                questions = [Question(**question) for question in json.load(f)]
        return questions

    def get_one_question(self, task_id=None) -> Question:
        """
        Get a random, or requested question from the HuggingFace endpoint.

        Args:
            task_id: When given, return the matching question from the full
                list; falls through to a random question if no match exists.

        Returns:
            Question: A Question object
        """
        if task_id:
            # Fixed: the original re-tested `task_id` in a redundant nested if.
            for question in self.get_questions():
                if question.task_id == task_id:
                    return question
        try:
            url = str(self.settings.scoring_api_base_url) + "random-question"
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            return Question(**response.json())
        except Exception as e:  # was a bare `except:`
            # Read local file instead, dealing with rate limits, etc.
            logger.warning(f"Falling back to a cached random question: {e}")
            return random.choice(self.get_questions())

    def _read_answer_file(self) -> List[str]:
        """Load answers.json and format each pair for the submit payload."""
        with open("answers.json", "r") as f:
            pairs = [QuestionAnswerPair(**pair) for pair in json.load(f)]
        return [pair.get_answer() for pair in pairs]

    def submit_answers(self) -> str:
        """Submits saved answers to the scoring endpoint and returns the result."""
        answers_payload = self._read_answer_file()
        agent_code = f"https://huggingface.co/spaces/{self.settings.space_id}/tree/main"
        submission_data = {
            "username": self.settings.username,
            "agent_code": agent_code,
            "answers": answers_payload}
        submit_url = str(self.settings.scoring_api_base_url) + "submit"
        logger.info(f"Submitting {len(answers_payload)} answers to: {submit_url}")
        try:
            response = requests.post(
                submit_url, json=submission_data, timeout=60)
            response.raise_for_status()
            results = Results.model_validate(response.json())
            logger.info(
                f"Submission successful.\n"
                f"User: {results.username}.\n"
                f"Overall Score: {results.score}%.\n"
                f"Correct Count: {results.correct_count}.\n"
                f"Total Attempted: {results.total_attempted}.\n"
                f"Message: {results.message}.\n"
                f"Timestamp: {results.timestamp}.\n"
            )
            status_message = (
                f"Submission Successful!\n"
                f"User: {results.username}\n"
                f"Overall Score: {results.score}% "
                f"({results.correct_count}/{results.total_attempted} correct)\n"
                f"Message: {results.message}"
            )
            return status_message
        except requests.exceptions.HTTPError as e:
            error_detail = f"Server responded with status {e.response.status_code}."
            try:
                error_json = e.response.json()
                error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
            except requests.exceptions.JSONDecodeError:
                error_detail += f" Response: {e.response.text[:500]}"
            status_message = f"Submission Failed: {error_detail}"
            # Failures are now logged at ERROR level (previously INFO).
            logger.error(status_message)
            return status_message
        except requests.exceptions.Timeout:
            status_message = "Submission Failed: The request timed out."
            logger.error(status_message)
            return status_message
        except requests.exceptions.RequestException as e:
            status_message = f"Submission Failed: Network error - {e}"
            logger.error(status_message)
            return status_message
        except Exception as e:
            status_message = f"An unexpected error occurred during submission: {e}"
            logger.error(status_message)
            return status_message
runner.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from settings import Settings
2
+ from models import QuestionAnswerPair
3
+ from agent import BasicAgent
4
+ import pandas as pd
5
+ import logging
6
+ import json
7
+ import asyncio
8
+ import nest_asyncio
9
+ nest_asyncio.apply()
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
class Runner():
    """Runs the agent over a batch of questions concurrently and caches answers."""

    def __init__(self, settings: Settings):
        # Settings are passed through to every BasicAgent instance created per task.
        self.settings = settings

    def _save_pairs(self, pairs: list[QuestionAnswerPair]):
        # Persist answers so Evaluator.submit_answers() can read them later.
        # None entries (shouldn't normally occur) are skipped defensively.
        answers = [pair.model_dump() for pair in pairs if pair is not None]
        with open("answers.json", "w") as f:
            json.dump(answers, f, indent=4)

    def _enrich_question_text(self, item):
        # Wrap the raw question with answer-format instructions plus the
        # task_id (the agent needs it to fetch any attached file via tools).
        task_id = item.task_id
        file_name = item.file_name
        question_text = (
            f"{item.question} "
            "Think hard to answer. Parse all statements in the question to make a plan. "
            "Your final answer should be a number or as few words as possible. "
            "If needed, use a comma separated list of numbers and/or strings. Critically "
            f"review your answer before making it the final answer. task_id: {task_id}."
        )
        if file_name:
            question_text = f"{question_text} file_name: {file_name} (use tools to fetch the file)"
        return question_text

    async def _run_agent_async(self, item):
        """Runs the agent asynchronously."""
        # A fresh BasicAgent per task keeps runs isolated; the blocking
        # agent call is pushed to a worker thread so tasks overlap on I/O.
        task_id = item.task_id
        question_text = self._enrich_question_text(item)
        try:
            answer = await asyncio.to_thread(BasicAgent(self.settings), question_text)
        except Exception as e:
            # An agent failure yields an error-string answer rather than
            # aborting the whole gather.
            logger.error(f"Error running agent on task {task_id}: {e}")
            answer = f"AGENT ERROR: {e}"
        return QuestionAnswerPair(task_id=task_id,
                                  question=item.question, answer=str(answer))

    def _assign_questions(self, questions):
        """Runs the asynchronous loop and returns task outputs."""
        # Returns an awaitable gathering one coroutine per question; awaited
        # by run_agent via run_until_complete.
        tasks = [self._run_agent_async(item) for item in questions]
        return asyncio.gather(*tasks)

    def run_agent(self, questions) -> pd.DataFrame:
        """Run the agent(s) async, save answers and return a dataframe"""
        # Assign questions to agents and wait
        # NOTE(review): get_running_loop() assumes an event loop is already
        # running in this thread (e.g. the web framework's); the module-level
        # nest_asyncio.apply() is what makes the nested run_until_complete
        # below legal. Despite its name, run_tasks_in_thread executes in the
        # CURRENT thread — confirm before renaming or refactoring.
        loop = asyncio.get_running_loop()

        def run_tasks_in_thread():
            question_answer_pairs = loop.run_until_complete(
                self._assign_questions(questions))
            return question_answer_pairs

        pairs = run_tasks_in_thread()

        # save json to disk and return a dataframe
        self._save_pairs(pairs)
        results_log = [pair.model_dump() for pair in pairs if pair is not None]
        if not results_log:
            logger.warning("Agent did not produce any answers to submit.")

        return pd.DataFrame(results_log)
tools.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import logging
4
+ logger = logging.getLogger(__name__)
5
+ import requests
6
+ import shutil
7
+ from typing import Any
8
+ import urllib.parse
9
+ from board_to_fen.predict import get_fen_from_image_path
10
+ from google import genai
11
+ from google.genai import types
12
+ from litellm import completion
13
+ from smolagents import Tool
14
+ from settings import Settings
15
+
16
+
17
class BaseCustomTool(Tool):
    """Base class for smolagents tools that need access to application Settings."""

    def __init__(self, settings):
        super().__init__()
        # Shared settings object (API keys, endpoint URLs) used by subclasses.
        self.settings = settings
21
+
22
class GetTaskFileTool(BaseCustomTool):
    """Tool that downloads a task's attached file and returns its path."""

    name = "get_task_file_tool"
    description = """Download the file_name associated with a given task_id. Get absolute file path"""
    inputs = {
        "task_id": {"type": "string", "description": "Task ID"},
        "file_name": {"type": "string", "description": "File name"},
    }
    output_type = "string"

    def __init__(self, settings):
        super().__init__(settings)
        self.directory_name = "downloads"
        self.create_dir()

    def forward(self, task_id: str, file_name: str) -> str:
        """Download the task file, falling back to the bundled local copy.

        Returns:
            Absolute path of the saved file. (Bug fix: the fallback branch
            previously returned a relative path even though the success branch
            and the tool description promise an absolute one.)
        """
        destination = f"{self.directory_name}/{file_name}"
        try:
            response = requests.get(
                f"{self.settings.evaluation_api_base_url}/files/{task_id}",
                timeout=15)
            response.raise_for_status()
            with open(destination, 'wb') as file:
                file.write(response.content)
        except Exception as e:
            # Fetch the local file instead, dealing with rate limits, etc.
            # The failure is now logged instead of silently swallowed.
            logger.warning(f"Download failed for task {task_id} ({e}); using local copy.")
            shutil.copy2(f"files/{file_name}", destination)
        return os.path.abspath(destination)

    def create_dir(self):
        """Create the download directory if it doesn't already exist."""
        if not os.path.exists(self.directory_name):
            os.makedirs(self.directory_name)
            logger.info(f"Directory '{self.directory_name}' created successfully.")
        else:
            logger.debug(f"Directory '{self.directory_name}' already exists.")
55
+
56
class VideoUnderstandingTool(BaseCustomTool):
    """Tool that asks a Gemini model questions about a YouTube video."""

    name = "VideoUnderstanding"
    description = "Prompt a YouTube video with questions to understand its content."
    inputs = {
        "youtube_url": {"type": "string", "description": "The URL of the YouTube video"},
        "prompt": {"type": "string", "description": "A question or request regarding the video"},
    }
    output_type = "string"

    def __init__(self, settings, model):
        super().__init__(settings)
        # Gemini model identifier used for the generate_content call.
        self.model = model

    def forward(self, youtube_url: str, prompt: str) -> str:
        """Send the video URL plus *prompt* to Gemini and return its answer.

        Returns:
            The model's text response, or an error string on failure.
            (Bug fix: previously returned ``False`` on error even though
            ``output_type`` declares a string.)
        """
        client = genai.Client(api_key=self.settings.gemini_api_key.get_secret_value())
        try:
            video_description = client.models.generate_content(
                model=self.model,
                contents=types.Content(
                    parts=[
                        types.Part(
                            file_data=types.FileData(file_uri=youtube_url)
                        ),
                        types.Part(text=prompt)
                    ]
                )
            )
            return video_description.text
        except Exception as e:
            logger.error(f"Error understanding video: {e}")
            return f"Error understanding video: {e}"
87
+
88
class AudioUnderstandingTool(BaseCustomTool):
    """Tool that asks a Gemini model questions about a local audio file."""

    name = "AudioUnderstanding"
    description = "Prompt a local audio file with questions to understand its content."
    inputs = {
        "file_path": {"type": "string", "description": "The local file of the audio"},
        "prompt": {"type": "string", "description": "A question or request regarding the audio"},
    }
    output_type = "string"

    def __init__(self, settings, model):
        super().__init__(settings)
        # Gemini model identifier used for the generate_content call.
        self.model = model

    def forward(self, file_path: str, prompt: str) -> str:
        """Upload the audio file to Gemini, ask *prompt*, and return the answer.

        Returns:
            The model's text response, or an error string on failure.
            (Bug fix: previously returned ``False`` on error even though
            ``output_type`` declares a string.)
        """
        client = genai.Client(api_key=self.settings.gemini_api_key.get_secret_value())
        try:
            mp3_file = client.files.upload(file=f"{file_path}")
            audio_description = client.models.generate_content(
                model=self.model,
                contents=[prompt, mp3_file]
            )
            return audio_description.text
        except Exception as e:
            logger.error(f"Error understanding audio: {e}")
            return f"Error understanding audio: {e}"
113
+
114
class ConvertChessMoveTool(BaseCustomTool):
    """Tool that converts a coordinate-notation chess move to algebraic notation via an LLM."""

    name = "ConvertChessMove"
    description = "Convert a chess move from coordinate notation to algebraic notation."
    inputs = {
        "piece_placement": {"type": "string", "description": "The chess piece placement in plain text"},
        "move": {"type": "string", "description": "The move in coordinate notation (e.g., e2e4)"},
    }
    output_type = "string"

    def __init__(self, settings, model):
        super().__init__(settings)
        # OpenRouter model identifier passed to litellm.completion.
        self.model = model

    def forward(self, piece_placement: str, move: str) -> str:
        # Deterministic (temperature 0) single-turn conversion request; the
        # prompt instructs the model to reply with the notation only.
        move_message = f"""Convert this chess move from coordinate notation to algebraic
        notation: {move}. Use the following {piece_placement}. Do not provide any additional
        thinking or commentary in the response, the algebraic notation only."""
        messages = [{ "content": move_message,"role": "user"}]
        response = completion(
            model=self.model,
            temperature=0.0,
            messages=messages,
            api_key=self.settings.openrouter_api_key.get_secret_value()
        )
        return response.choices[0].message.content
139
+
140
class BestChessMoveTool(BaseCustomTool):
    """Tool that queries a chess-evaluation web service for the best move."""

    name = "BestChessMove"
    description = "Get best chess move in coordinate notation based on a FEN representation."
    inputs = {
        "fen": {"type": "string", "description": "The FEN (Forsyth-Edwards Notation) \
            representation of the chess position. Example \
            rn1q1rk1/pp2b1pp/2p2n2/3p1pB1/3P4/1QP2N2/PP1N1PPP/R4RK1 b - - 1 11"},
    }
    output_type = "string"

    def forward(self, fen: str) -> str:
        """Query the evaluation endpoint at depth 15 and return the best move.

        Returns:
            The move token in coordinate notation, or an error string on
            failure. (Bug fix: the except path previously fell through and
            returned ``None`` despite ``output_type`` declaring a string.)
        """
        try:
            url = f"{self.settings.chess_eval_url}?fen={urllib.parse.quote(fen)}&depth=15"
            response = requests.get(url, timeout=15)
            if response.status_code == 200:
                # Parse the body once (previously json.loads(response.text) ran twice).
                payload = response.json()
                if payload['success']:
                    # assumes "bestmove" looks like "bestmove e2e4 ..." — the
                    # second token is the move itself; TODO confirm against API.
                    return payload['bestmove'].split()[1]
            raise ValueError(f"Error getting chess evaluation: {response.status_code}")
        except Exception as e:
            logger.error(f"Error getting chess evaluation: {e}")
            return f"Error getting chess evaluation: {e}"
160
+
161
class ChessGameFenTool(BaseCustomTool):
    """Tool that derives a FEN string from a plain-text piece placement via an LLM."""

    name = "ChessGameFen"
    description = "Get a FEN representation given chess piece placement and a move."
    inputs = {
        "piece_placement": {"type": "string", "description": "The chess piece placement in plain text"},
        "player_turn": {"type": "string",
                        "description": "The player with the next turn in the match, black or white"},
    }
    output_type = "string"

    def __init__(self, settings, model):
        super().__init__(settings)
        # OpenRouter model identifier passed to litellm.completion.
        self.model = model

    def forward(self, piece_placement: str, player_turn: str) -> str:
        """Use the tool."""
        # Deterministic (temperature 0) single-turn request; the prompt
        # instructs the model to reply with the FEN only.
        fen_message = f"""Assuming {player_turn} has the next turn, Use the following placement
        {piece_placement} and provide the board state as FEN. Do not provide any
        additional thinking or commentary in the response, the FEN only."""
        messages = [{ "content": fen_message,"role": "user"}]
        response = completion(
            model=self.model,
            temperature=0.0,
            messages=messages,
            api_key=self.settings.openrouter_api_key.get_secret_value()
        )
        return response.choices[0].message.content
188
+
189
class ChessPiecePlacementTool(Tool):
    """Tool that extracts piece placement from a chessboard image.

    Delegates entirely to board_to_fen's get_fen_from_image_path; it does not
    need Settings, so it subclasses Tool directly rather than BaseCustomTool.
    """

    name = "ChessPiecePlacement"
    description = "Get chess piece placement information from an image of a board."
    inputs = {
        "image_path": {"type": "string", "description": "The local file of the chess board image"},
    }
    output_type = "string"

    def forward(self, image_path: str) -> str:
        # Returns whatever board_to_fen produces for the image (a FEN-style
        # placement string — presumably placement only; TODO confirm).
        return get_fen_from_image_path(image_path)
199
+