AGAZO_Final_Assignment / npb_agent.py
agazo's picture
Update npb_agent.py
0e96680 verified
import re
import constants
import time
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.agents import AgentExecutor, create_tool_calling_agent, create_openai_functions_agent
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import SystemMessage
# --- Custom Tools ---
from nb_tool import get_team_roster
from nb_tool import get_npb_player_info
# System prompt for the NPB agent. It asks the model to show its reasoning and
# to end with a "FINAL ANSWER:" marker (the marker NpbAgent.__call__ searches
# for with a regex), and it outlines the roster-lookup strategy for questions
# about players listed before/after another player.
PROMPT = (
    "\n"
    "You are an agent specialized in answering questions related to Nippon "
    "Professional Baseball players. Report your thoughts, and\n"
    "finish your answer with the following template: "
    "FINAL ANSWER: [YOUR FINAL ANSWER].\n"
    "If you need to identify players that come before or after other players "
    "in NPB, you can first get the player number and team.\n"
    "Then you get the players of the team for that season. "
    "Then you just find what you need out of that list.\n"
)
class NpbAgent:
    """Question-answering agent for Nippon Professional Baseball (NPB) players.

    Wraps a LangChain tool-calling agent backed by ``ChatGoogleGenerativeAI``
    and the ``get_npb_player_info`` / ``get_team_roster`` tools. Calling an
    instance with a question runs the agent. It returns the text that follows
    the last ``FINAL ANSWER:`` marker in the model's reply; the system prompt
    asks the model to finish with that marker.
    """

    # Captures the rest of the line after a "FINAL ANSWER:" marker. Matching is
    # case-insensitive because models do not always reproduce the exact casing
    # the prompt asks for. ``\s*`` also consumes newlines, so an answer placed on
    # the line after the marker is still captured.
    _FINAL_ANSWER_RE = re.compile(r"FINAL ANSWER:\s*(.*)", re.IGNORECASE)

    def __init__(
        self,
        *,
        temperature: float = 0.4,
        timeout: float = 20,
        max_iterations: int = 15,
        verbose: bool = True,
    ):
        """Build the chat model, tool list, prompt and agent executor.

        All arguments are keyword-only. They default to the values that were
        previously hard-coded here, so ``NpbAgent()`` behaves as before.

        Args:
            temperature: Sampling temperature passed to the Gemini chat model.
            timeout: Timeout passed to the Gemini chat model (presumably
                seconds; TODO confirm against langchain_google_genai).
            max_iterations: Step limit passed to ``AgentExecutor``.
            verbose: Verbosity flag passed to ``AgentExecutor``.
        """
        llm = ChatGoogleGenerativeAI(
            model=constants.MODEL,
            api_key=constants.API_KEY,
            temperature=temperature,
            timeout=timeout,
        )
        tools = [get_npb_player_info, get_team_roster]
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=PROMPT),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
            # Required by create_tool_calling_agent. It receives the agent's
            # intermediate tool-call / tool-result messages during a run.
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ])
        agent = create_tool_calling_agent(llm, tools, prompt=prompt)
        self.executor = AgentExecutor(
            agent=agent,
            tools=tools,
            verbose=verbose,
            max_iterations=max_iterations,
        )

    def __call__(self, question: str) -> str:
        """Run the agent on *question* and return its final answer.

        Every call passes an empty chat history, so calls do not share context.

        Args:
            question: The user question to answer.

        Returns:
            The stripped text after the last non-empty ``FINAL ANSWER:`` marker
            in the agent's output. If there is no such marker, the full output
            text is returned.
        """
        print(f"NpbAgent agent received: {question[:50]}...")
        result = self.executor.invoke({
            "input": question,
            "chat_history": [],
        })
        output = result.get("output")
        if output is None:
            output = "No answer returned."
        text = self._coerce_text(output)
        print(f"Agent response: {text}")
        return self._extract_final_answer(text)

    @staticmethod
    def _coerce_text(output) -> str:
        """Return the agent output as plain text.

        LangChain message content may be a ``str`` or a list of ``str`` /
        ``dict`` content blocks, where text blocks keep their text under
        ``"text"``. List parts are joined with newlines, and non-text dict
        blocks are skipped. Any other value is converted with ``str()``.
        """
        if isinstance(output, str):
            return output
        if isinstance(output, list):
            parts = []
            for part in output:
                if isinstance(part, str):
                    parts.append(part)
                elif isinstance(part, dict) and "text" in part:
                    parts.append(str(part["text"]))
            return "\n".join(parts)
        return str(output)

    @classmethod
    def _extract_final_answer(cls, text: str) -> str:
        """Return the answer after the last non-empty ``FINAL ANSWER:`` marker.

        The last occurrence is used because the prompt tells the model to
        *finish* with the marker. Earlier occurrences may be the model echoing
        the template while reasoning. Falls back to *text* unchanged when no
        marker is followed by content.
        """
        answers = [
            answer.strip() for answer in cls._FINAL_ANSWER_RE.findall(text)
        ]
        answers = [answer for answer in answers if answer]
        return answers[-1] if answers else text