File size: 2,179 Bytes
991089b
 
 
 
 
 
 
 
 
6b58720
dff08c2
991089b
6b58720
991089b
2b4a8ed
991089b
1e6cc91
 
0e96680
1e6cc91
991089b
 
 
 
 
 
 
 
 
 
 
dff08c2
0e96680
991089b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import re
import constants
import time
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.agents import AgentExecutor, create_tool_calling_agent, create_openai_functions_agent
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import SystemMessage

# --- Custom Tools ---
from nb_tool import get_team_roster
from nb_tool import get_npb_player_info


# System prompt for NpbAgent. It is passed verbatim as the SystemMessage content,
# including the leading indentation and trailing spaces inside the literal.
# It asks the model to end its reply with "FINAL ANSWER: ...", which
# NpbAgent.__call__ extracts with a regex. It also suggests a roster-based
# strategy for questions about players who come before/after another player.
PROMPT = """
    You are an agent specialized in answering questions related to Nippon Professional Baseball players. Report your thoughts, and 
    finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. 

    If you need to identify players that come before or after other players in NPB, you can first get the player number and team. 
    Then you get the players of the team for that season. Then you just find what you need out of that list.
    
"""

class NpbAgent:
    """Tool-calling LangChain agent that answers questions about NPB players.

    Combines a Gemini chat model (configured from ``constants``) with the
    ``get_npb_player_info`` and ``get_team_roster`` tools. Instances are
    callable: ``NpbAgent()(question)`` returns the text the model placed after
    its ``FINAL ANSWER:`` marker, or the whole reply if no marker is found.
    """

    # Marker the system prompt asks the model to finish with. Markdown bold
    # right after the colon ("**FINAL ANSWER:** x") is skipped so the asterisks
    # do not end up in the answer. The capture group stops at the end of a line.
    _FINAL_ANSWER_RE = re.compile(r"FINAL ANSWER:\s*\**\s*(.*)")

    def __init__(
        self,
        *,
        temperature: float = 0.4,
        timeout: float = 20,
        max_iterations: int = 15,
        verbose: bool = True,
    ):
        """Build the chat model, the tool list, and the agent executor.

        All arguments are keyword-only and default to the values this agent
        originally hard-coded, so ``NpbAgent()`` behaves exactly as before.

        Args:
            temperature: Sampling temperature for ``ChatGoogleGenerativeAI``.
            timeout: Request timeout for ``ChatGoogleGenerativeAI``
                (presumably seconds — confirm against langchain-google-genai).
            max_iterations: Maximum number of agent steps in ``AgentExecutor``.
            verbose: Whether ``AgentExecutor`` prints its intermediate steps.
        """
        llm = ChatGoogleGenerativeAI(
            model=constants.MODEL,
            api_key=constants.API_KEY,
            temperature=temperature,
            timeout=timeout,
        )

        tools = [
            get_npb_player_info,
            get_team_roster,
        ]

        # "chat_history" is supplied by __call__ (always empty there), so each
        # question is answered independently. "agent_scratchpad" is the slot
        # create_tool_calling_agent fills with intermediate tool calls/results.
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=PROMPT),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ])

        agent = create_tool_calling_agent(llm, tools, prompt=prompt)
        self.executor = AgentExecutor(
            agent=agent,
            tools=tools,
            verbose=verbose,
            max_iterations=max_iterations,
        )

    @staticmethod
    def _as_text(output: object) -> str:
        """Coerce an executor ``"output"`` value to a plain string.

        NOTE(review): depending on the Gemini / langchain-google-genai version,
        message content may arrive as a list of parts (plain strings or
        ``{"type": "text", "text": ...}`` dicts) instead of a str. ``re`` raises
        ``TypeError`` on such values. Text parts are concatenated and non-text
        parts are dropped. Any other non-str value falls back to ``str()``.
        """
        if isinstance(output, str):
            return output
        if isinstance(output, list):
            texts = []
            for part in output:
                if isinstance(part, str):
                    texts.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    texts.append(part["text"])
            return "".join(texts)
        return str(output)

    def __call__(self, question: str) -> str:
        """Run the agent on *question* and return its final answer.

        Args:
            question: Natural-language question about NPB players.

        Returns:
            The stripped text after the last ``FINAL ANSWER:`` marker in the
            agent's reply, up to the end of that line. If no marker is present,
            the whole reply is returned. When the executor result has no
            ``"output"`` key, the reply is ``"No answer returned."``.
        """
        print(f"NpbAgent agent received: {question[:50]}...")
        result = self.executor.invoke({
            "input": question,
            "chat_history": []
        })
        output = self._as_text(result.get("output", "No answer returned."))
        print(f"Agent response: {output}")
        # The prompt asks the model to report its thoughts before answering, and
        # those thoughts may mention the marker too, so take the last occurrence.
        answers = self._FINAL_ANSWER_RE.findall(output)
        if answers:
            return answers[-1].strip()
        return output