Jeremy Live
commited on
Commit
·
3c8de53
1
Parent(s):
0f903c2
Creating SQL agent with memory.
Browse files
app.py
CHANGED
@@ -13,6 +13,7 @@ try:
|
|
13 |
from langchain_community.utilities import SQLDatabase
|
14 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
15 |
from langchain.agents.agent_types import AgentType
|
|
|
16 |
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
17 |
import pymysql
|
18 |
from dotenv import load_dotenv
|
@@ -151,7 +152,7 @@ def create_agent():
|
|
151 |
logger.info("Starting agent creation process...")
|
152 |
|
153 |
def create_agent(llm, db_connection):
|
154 |
-
"""Create and return a SQL database agent."""
|
155 |
if not llm:
|
156 |
error_msg = "Cannot create agent: LLM is not available"
|
157 |
logger.error(error_msg)
|
@@ -163,7 +164,15 @@ def create_agent(llm, db_connection):
|
|
163 |
return None, error_msg
|
164 |
|
165 |
try:
|
166 |
-
logger.info("Creating SQL agent...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
|
168 |
# Create the database toolkit with additional configuration
|
169 |
toolkit = SQLDatabaseToolkit(
|
@@ -171,7 +180,7 @@ def create_agent(llm, db_connection):
|
|
171 |
llm=llm
|
172 |
)
|
173 |
|
174 |
-
# Create the agent with more detailed configuration
|
175 |
agent = create_sql_agent(
|
176 |
llm=llm,
|
177 |
toolkit=toolkit,
|
@@ -179,7 +188,9 @@ def create_agent(llm, db_connection):
|
|
179 |
verbose=True,
|
180 |
handle_parsing_errors=True, # Better error handling for parsing
|
181 |
max_iterations=10, # Limit the number of iterations
|
182 |
-
early_stopping_method="generate" # Stop early if the agent is stuck
|
|
|
|
|
183 |
)
|
184 |
|
185 |
# Test the agent with a simple query
|
@@ -328,17 +339,22 @@ def convert_to_messages_format(chat_history):
|
|
328 |
return messages
|
329 |
|
330 |
async def stream_agent_response(question: str, chat_history: List) -> List[Dict]:
|
331 |
-
"""Procesa la pregunta del usuario y devuelve la respuesta del agente."""
|
332 |
# Initialize response
|
333 |
response_text = ""
|
334 |
messages = []
|
335 |
|
336 |
-
# Add previous chat history
|
337 |
if chat_history:
|
338 |
-
|
|
|
|
|
|
|
|
|
|
|
339 |
|
340 |
-
# Add user's question
|
341 |
-
user_message =
|
342 |
messages.append(user_message)
|
343 |
|
344 |
if not agent:
|
@@ -351,7 +367,6 @@ async def stream_agent_response(question: str, chat_history: List) -> List[Dict]
|
|
351 |
f"Error: {agent_error}"
|
352 |
)
|
353 |
assistant_message = {"role": "assistant", "content": error_msg}
|
354 |
-
messages.append(assistant_message)
|
355 |
return [assistant_message]
|
356 |
|
357 |
try:
|
|
|
13 |
from langchain_community.utilities import SQLDatabase
|
14 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
15 |
from langchain.agents.agent_types import AgentType
|
16 |
+
from langchain.memory import ConversationBufferWindowMemory
|
17 |
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
18 |
import pymysql
|
19 |
from dotenv import load_dotenv
|
|
|
152 |
logger.info("Starting agent creation process...")
|
153 |
|
154 |
def create_agent(llm, db_connection):
|
155 |
+
"""Create and return a SQL database agent with conversation memory."""
|
156 |
if not llm:
|
157 |
error_msg = "Cannot create agent: LLM is not available"
|
158 |
logger.error(error_msg)
|
|
|
164 |
return None, error_msg
|
165 |
|
166 |
try:
|
167 |
+
logger.info("Creating SQL agent with memory...")
|
168 |
+
|
169 |
+
# Create conversation memory
|
170 |
+
memory = ConversationBufferWindowMemory(
|
171 |
+
memory_key="chat_history",
|
172 |
+
k=5, # Keep last 5 message exchanges in memory
|
173 |
+
return_messages=True,
|
174 |
+
output_key="output"
|
175 |
+
)
|
176 |
|
177 |
# Create the database toolkit with additional configuration
|
178 |
toolkit = SQLDatabaseToolkit(
|
|
|
180 |
llm=llm
|
181 |
)
|
182 |
|
183 |
+
# Create the agent with memory and more detailed configuration
|
184 |
agent = create_sql_agent(
|
185 |
llm=llm,
|
186 |
toolkit=toolkit,
|
|
|
188 |
verbose=True,
|
189 |
handle_parsing_errors=True, # Better error handling for parsing
|
190 |
max_iterations=10, # Limit the number of iterations
|
191 |
+
early_stopping_method="generate", # Stop early if the agent is stuck
|
192 |
+
memory=memory, # Add memory to the agent
|
193 |
+
return_intermediate_steps=True # Important for memory to work properly
|
194 |
)
|
195 |
|
196 |
# Test the agent with a simple query
|
|
|
339 |
return messages
|
340 |
|
341 |
async def stream_agent_response(question: str, chat_history: List) -> List[Dict]:
|
342 |
+
"""Procesa la pregunta del usuario y devuelve la respuesta del agente con memoria de conversación."""
|
343 |
# Initialize response
|
344 |
response_text = ""
|
345 |
messages = []
|
346 |
|
347 |
+
# Add previous chat history in the correct format for the agent
|
348 |
if chat_history:
|
349 |
+
# Convert chat history to the format expected by the agent's memory
|
350 |
+
for msg in chat_history:
|
351 |
+
if msg["role"] == "user":
|
352 |
+
messages.append(HumanMessage(content=msg["content"]))
|
353 |
+
elif msg["role"] == "assistant":
|
354 |
+
messages.append(AIMessage(content=msg["content"]))
|
355 |
|
356 |
+
# Add current user's question
|
357 |
+
user_message = HumanMessage(content=question)
|
358 |
messages.append(user_message)
|
359 |
|
360 |
if not agent:
|
|
|
367 |
f"Error: {agent_error}"
|
368 |
)
|
369 |
assistant_message = {"role": "assistant", "content": error_msg}
|
|
|
370 |
return [assistant_message]
|
371 |
|
372 |
try:
|