AGAZO_Final_Assignment / langchain_agent.py
Alexandre Gazola
trocando implementacao para obtencao do FEN
1c14abd
raw
history blame
2.38 kB
import re
import constants
import time
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.agents import AgentExecutor, create_tool_calling_agent, create_openai_functions_agent
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import SystemMessage
# --- Custom Tools ---
from wikipedia_tool import wikipedia_revision_by_year_keyword
from count_max_bird_species_tool import count_max_bird_species_in_video
from image_to_text_tool import image_to_text
from internet_search_tool import internet_search
from botanical_classification_tool import get_botanical_classification
from excel_parser_tool import parse_excel
from analyse_chess_position_tool import get_chess_best_move
from convert_chessboard_image_to_fen_tool import convert_chessboard_image_to_fen
from chess_image_to_fen_tool import chess_image_to_fen
class LangChainAgent:
    """Tool-calling LangChain agent backed by ``ChatGoogleGenerativeAI``.

    Calling an instance with a question runs the agent executor and returns
    the text after the first ``FINAL ANSWER:`` marker in the agent's output,
    or the whole output when no marker is present.
    """

    # Captures the remainder of the line after the first "FINAL ANSWER:"
    # marker; ``\s*`` may also skip a line break right after the colon.
    _FINAL_ANSWER_RE = re.compile(r"FINAL ANSWER:\s*(.*)")

    def __init__(self):
        """Build the chat model, register the tools and create the executor."""
        llm = ChatGoogleGenerativeAI(
            model=constants.MODEL,
            api_key=constants.API_KEY,
            temperature=0.4,
            timeout=20,
        )
        tools = [
            wikipedia_revision_by_year_keyword,
            count_max_bird_species_in_video,
            image_to_text,
            internet_search,
            get_botanical_classification,
            parse_excel,
            # The alternative convert_chessboard_image_to_fen tool is disabled;
            # chess_image_to_fen is the image-to-FEN converter in use.
            chess_image_to_fen,
            get_chess_best_move,
        ]
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=constants.PROMPT_LIMITADOR_LLM),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ])
        agent = create_tool_calling_agent(llm, tools, prompt=prompt)
        self.executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

    @staticmethod
    def _output_to_text(output) -> str:
        """Normalise the executor's ``output`` value to a plain string.

        Message content may be a string or a list of content blocks (plain
        strings or dicts such as ``{"type": "text", "text": ...}``). The
        marker regex needs a ``str``, so text blocks are concatenated and
        non-text blocks are skipped. A missing/``None`` output becomes
        ``"No answer returned."``.
        """
        if output is None:
            return "No answer returned."
        if isinstance(output, str):
            return output
        if isinstance(output, list):
            pieces = []
            for block in output:
                if isinstance(block, str):
                    pieces.append(block)
                elif isinstance(block, dict) and block.get("type", "text") == "text":
                    pieces.append(str(block.get("text", "")))
            return "".join(pieces)
        return str(output)

    def __call__(self, question: str) -> str:
        """Run the agent on *question* and return its final answer.

        Args:
            question: User question, passed to the executor as ``input``.

        Returns:
            The stripped text following the first ``FINAL ANSWER:`` marker,
            or the full (normalised) agent output when the marker is absent.
        """
        print(f"LangChain agent received: {question[:50]}...")
        result = self.executor.invoke({
            "input": question,
            # Every call starts with an empty history; no memory is kept.
            "chat_history": [],
        })
        output = self._output_to_text(result.get("output"))
        print(f"Agent response: {output}")
        match = self._FINAL_ANSWER_RE.search(output)
        return match.group(1).strip() if match else output