Spaces:
Sleeping
Sleeping
File size: 2,179 Bytes
991089b 6b58720 dff08c2 991089b 6b58720 991089b 2b4a8ed 991089b 1e6cc91 0e96680 1e6cc91 991089b dff08c2 0e96680 991089b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import re
import constants
import time
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.agents import AgentExecutor, create_tool_calling_agent, create_openai_functions_agent
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import SystemMessage
# --- Custom Tools ---
from nb_tool import get_team_roster
from nb_tool import get_npb_player_info
# System prompt given to the NPB agent (passed as the SystemMessage content
# when the chat prompt is built in NpbAgent.__init__).
# It asks the model to end its reply with "FINAL ANSWER: [...]"; that marker
# is what NpbAgent.__call__ searches for when extracting the short answer, so
# the wording of the template and the parsing regex must stay in sync.
PROMPT = """
You are an agent specialized in answering questions related to Nippon Professional Baseball players. Report your thoughts, and
finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
If you need to identify players that come before or after other players in NPB, you can first get the player number and team.
Then you get the players of the team for that season. Then you just find what you need out of that list.
"""
class NpbAgent:
    """Tool-calling agent that answers questions about NPB players.

    The instance is callable: ``agent(question)`` runs the underlying
    ``AgentExecutor`` and returns the text that follows the
    ``FINAL ANSWER:`` marker requested by ``PROMPT``, or the raw agent
    output when the marker is missing.
    """

    # Compiled once; captures the remainder of the line after the marker
    # (``\s*`` may also skip a line break directly after the colon).
    _FINAL_ANSWER_RE = re.compile(r"FINAL ANSWER:\s*(.*)")

    def __init__(self) -> None:
        """Build the LLM, the prompt and the agent executor.

        Model name and API key are read from the ``constants`` module.
        """
        llm = ChatGoogleGenerativeAI(
            model=constants.MODEL,
            api_key=constants.API_KEY,
            temperature=0.4,
            timeout=20,  # presumably seconds -- confirm against the client docs
        )
        tools = [
            get_npb_player_info,
            get_team_roster,
        ]
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=PROMPT),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ])
        agent = create_tool_calling_agent(llm, tools, prompt=prompt)
        self.executor = AgentExecutor(
            agent=agent,
            tools=tools,
            verbose=True,
            max_iterations=15,
        )

    def __call__(self, question: str) -> str:
        """Run the agent on *question* and return its final answer.

        Args:
            question: The user's question, forwarded as the ``input`` value.
                An empty ``chat_history`` is sent on every call.

        Returns:
            The stripped text after the last ``FINAL ANSWER:`` marker in the
            agent output, or the whole output when no marker is present.
        """
        print(f"NpbAgent agent received: {question[:50]}...")
        result = self.executor.invoke({
            "input": question,
            "chat_history": [],
        })

        output = result.get("output")
        if output is None:
            # Covers both a missing key and an explicit ``None`` value.
            output = "No answer returned."
        elif not isinstance(output, str):
            # The regex below needs text; a non-string payload would
            # otherwise raise TypeError.
            output = str(output)
        print(f"Agent response: {output}")

        # The model is asked to report its thoughts first, so it may quote
        # the answer template before giving the real answer. The last
        # occurrence of the marker is the actual conclusion.
        answers = self._FINAL_ANSWER_RE.findall(output)
        if answers:
            return answers[-1].strip()
        return output
|