Jeremy Live
committed on
Commit
·
683b6ad
1
Parent(s):
0e57fc2
Revert "memory per session"
Browse files
This reverts commit 698b8b0faf1bc95aa16ebc5d7bb8c28bde8ecafc.
app.py
CHANGED
@@ -229,16 +229,8 @@ def create_agent():
|
|
229 |
|
230 |
logger.info("Starting agent creation process...")
|
231 |
|
232 |
-
def create_agent(llm, db_connection
|
233 |
-
"""Create and return a SQL database agent with conversation memory.
|
234 |
-
|
235 |
-
Args:
|
236 |
-
llm: Language model instance
|
237 |
-
db_connection: SQLDatabase connection
|
238 |
-
run_test: If True, executes a small test query which may add entries
|
239 |
-
to agent memory. Disable for per-session agents to start
|
240 |
-
with a clean memory.
|
241 |
-
"""
|
242 |
if not llm:
|
243 |
error_msg = "Cannot create agent: LLM is not available"
|
244 |
logger.error(error_msg)
|
@@ -252,14 +244,10 @@ def create_agent(llm, db_connection, *, run_test: bool = True):
|
|
252 |
try:
|
253 |
logger.info("Creating SQL agent with memory...")
|
254 |
|
255 |
-
# Create conversation memory
|
256 |
-
try:
|
257 |
-
memory_k = int(os.getenv("MEMORY_K", "8"))
|
258 |
-
except Exception:
|
259 |
-
memory_k = 8
|
260 |
memory = ConversationBufferWindowMemory(
|
261 |
memory_key="chat_history",
|
262 |
-
k=
|
263 |
return_messages=True,
|
264 |
output_key="output"
|
265 |
)
|
@@ -283,16 +271,15 @@ def create_agent(llm, db_connection, *, run_test: bool = True):
|
|
283 |
return_intermediate_steps=True # Important for memory to work properly
|
284 |
)
|
285 |
|
286 |
-
#
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
# # Continue even if test fails, as it might be due to model limitations
|
296 |
|
297 |
logger.info("SQL agent created successfully")
|
298 |
return agent, ""
|
@@ -510,7 +497,7 @@ def convert_to_messages_format(chat_history):
|
|
510 |
|
511 |
return messages
|
512 |
|
513 |
-
async def stream_agent_response(question: str, chat_history: List[List[str]]
|
514 |
"""Procesa la pregunta del usuario y devuelve la respuesta del agente con memoria de conversación."""
|
515 |
global agent # Make sure we can modify the agent's memory
|
516 |
|
@@ -530,8 +517,7 @@ async def stream_agent_response(question: str, chat_history: List[List[str]], se
|
|
530 |
user_message = HumanMessage(content=question)
|
531 |
messages.append(user_message)
|
532 |
|
533 |
-
|
534 |
-
if not active_agent:
|
535 |
error_msg = (
|
536 |
"## ⚠️ Error: Agente no inicializado\n\n"
|
537 |
"No se pudo inicializar el agente de base de datos. Por favor, verifica que:\n"
|
@@ -545,11 +531,11 @@ async def stream_agent_response(question: str, chat_history: List[List[str]], se
|
|
545 |
# Update the agent's memory with the full conversation history
|
546 |
try:
|
547 |
# Rebuild agent memory from chat history pairs
|
548 |
-
if hasattr(
|
549 |
-
|
550 |
for i in range(0, len(messages)-1, 2): # (user, assistant)
|
551 |
if i+1 < len(messages):
|
552 |
-
|
553 |
{"input": messages[i].content},
|
554 |
{"output": messages[i+1].content}
|
555 |
)
|
@@ -564,7 +550,7 @@ async def stream_agent_response(question: str, chat_history: List[List[str]], se
|
|
564 |
# Execute the agent with proper error handling
|
565 |
try:
|
566 |
# Let the agent use its memory; don't pass raw chat_history
|
567 |
-
response = await
|
568 |
logger.info(f"Agent response type: {type(response)}")
|
569 |
logger.info(f"Agent response content: {str(response)[:500]}...")
|
570 |
|
@@ -654,7 +640,7 @@ async def stream_agent_response(question: str, chat_history: List[List[str]], se
|
|
654 |
"Devuelve SOLO la consulta SQL en un bloque ```sql``` para responder a: "
|
655 |
f"{question}. No incluyas explicación ni texto adicional."
|
656 |
)
|
657 |
-
sql_only_resp = await
|
658 |
sql_only_text = str(sql_only_resp)
|
659 |
sql_query2 = extract_sql_query(sql_only_text)
|
660 |
if sql_query2 and looks_like_sql(sql_query2):
|
@@ -880,9 +866,6 @@ def create_ui():
|
|
880 |
if not env_ok:
|
881 |
gr.Warning("⚠️ " + env_message)
|
882 |
|
883 |
-
# Create session-scoped state for the agent with memory
|
884 |
-
session_agent_state = gr.State(value=None)
|
885 |
-
|
886 |
# Create the chat interface
|
887 |
with gr.Row():
|
888 |
chatbot = gr.Chatbot(
|
@@ -992,28 +975,12 @@ def create_ui():
|
|
992 |
# Hidden component for streaming output
|
993 |
streaming_output_display = gr.Textbox(visible=False)
|
994 |
|
995 |
-
return
|
996 |
-
demo,
|
997 |
-
chatbot,
|
998 |
-
chart_display,
|
999 |
-
question_input,
|
1000 |
-
submit_button,
|
1001 |
-
streaming_output_display,
|
1002 |
-
session_agent_state,
|
1003 |
-
)
|
1004 |
|
1005 |
def create_application():
|
1006 |
"""Create and configure the Gradio application."""
|
1007 |
# Create the UI components
|
1008 |
-
(
|
1009 |
-
demo,
|
1010 |
-
chatbot,
|
1011 |
-
chart_display,
|
1012 |
-
question_input,
|
1013 |
-
submit_button,
|
1014 |
-
streaming_output_display,
|
1015 |
-
session_agent_state,
|
1016 |
-
) = create_ui()
|
1017 |
|
1018 |
def user_message(user_input: str, chat_history: List[Dict[str, str]]) -> Tuple[str, List[Dict[str, str]]]:
|
1019 |
"""Add user message to chat history (messages format) and clear input."""
|
@@ -1030,36 +997,20 @@ def create_application():
|
|
1030 |
|
1031 |
return "", chat_history
|
1032 |
|
1033 |
-
async def bot_response(
|
1034 |
-
|
1035 |
-
session_agent,
|
1036 |
-
) -> Tuple[List[Dict[str, str]], Optional[go.Figure], Any]:
|
1037 |
-
"""Generate bot response using a session-scoped agent with memory."""
|
1038 |
if not chat_history:
|
1039 |
-
return chat_history, None
|
1040 |
|
1041 |
# Ensure last message is a user turn awaiting assistant reply
|
1042 |
last = chat_history[-1]
|
1043 |
if not isinstance(last, dict) or last.get("role") != "user" or not last.get("content"):
|
1044 |
-
return chat_history, None
|
1045 |
|
1046 |
try:
|
1047 |
question = last["content"]
|
1048 |
logger.info(f"Processing question: {question}")
|
1049 |
|
1050 |
-
# Ensure we have a session-specific agent with memory
|
1051 |
-
local_agent = session_agent
|
1052 |
-
if local_agent is None:
|
1053 |
-
try:
|
1054 |
-
local_agent, _ = create_agent(llm, db_connection, run_test=False)
|
1055 |
-
session_agent = local_agent
|
1056 |
-
logger.info("Created new session agent with memory")
|
1057 |
-
except Exception as e:
|
1058 |
-
logger.error(f"Could not create session agent: {e}")
|
1059 |
-
# Fallback to global agent if available
|
1060 |
-
session_agent = None
|
1061 |
-
local_agent = None
|
1062 |
-
|
1063 |
# Convert prior messages to pair history for stream_agent_response()
|
1064 |
pair_history: List[List[str]] = []
|
1065 |
i = 0
|
@@ -1078,11 +1029,7 @@ def create_application():
|
|
1078 |
i += 1
|
1079 |
|
1080 |
# Call the agent for this new user question
|
1081 |
-
assistant_message, chart_fig = await stream_agent_response(
|
1082 |
-
question,
|
1083 |
-
pair_history,
|
1084 |
-
session_agent=session_agent,
|
1085 |
-
)
|
1086 |
|
1087 |
# Append assistant message back into messages history
|
1088 |
chat_history.append({"role": "assistant", "content": assistant_message})
|
@@ -1119,14 +1066,14 @@ def create_application():
|
|
1119 |
)
|
1120 |
|
1121 |
logger.info("Response generation complete")
|
1122 |
-
return chat_history, chart_fig
|
1123 |
|
1124 |
except Exception as e:
|
1125 |
error_msg = f"## ❌ Error\n\nError al procesar la solicitud:\n\n```\n{str(e)}\n```"
|
1126 |
logger.error(error_msg, exc_info=True)
|
1127 |
# Ensure we add an assistant error message for the UI
|
1128 |
chat_history.append({"role": "assistant", "content": error_msg})
|
1129 |
-
return chat_history, None
|
1130 |
|
1131 |
# Event handlers
|
1132 |
with demo:
|
@@ -1138,8 +1085,8 @@ def create_application():
|
|
1138 |
queue=True
|
1139 |
).then(
|
1140 |
fn=bot_response,
|
1141 |
-
inputs=[chatbot
|
1142 |
-
outputs=[chatbot, chart_display
|
1143 |
api_name="ask"
|
1144 |
)
|
1145 |
|
@@ -1151,8 +1098,8 @@ def create_application():
|
|
1151 |
queue=True
|
1152 |
).then(
|
1153 |
fn=bot_response,
|
1154 |
-
inputs=[chatbot
|
1155 |
-
outputs=[chatbot, chart_display
|
1156 |
)
|
1157 |
|
1158 |
return demo
|
|
|
229 |
|
230 |
logger.info("Starting agent creation process...")
|
231 |
|
232 |
+
def create_agent(llm, db_connection):
|
233 |
+
"""Create and return a SQL database agent with conversation memory."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
if not llm:
|
235 |
error_msg = "Cannot create agent: LLM is not available"
|
236 |
logger.error(error_msg)
|
|
|
244 |
try:
|
245 |
logger.info("Creating SQL agent with memory...")
|
246 |
|
247 |
+
# Create conversation memory
|
|
|
|
|
|
|
|
|
248 |
memory = ConversationBufferWindowMemory(
|
249 |
memory_key="chat_history",
|
250 |
+
k=5, # Keep last 5 message exchanges in memory
|
251 |
return_messages=True,
|
252 |
output_key="output"
|
253 |
)
|
|
|
271 |
return_intermediate_steps=True # Important for memory to work properly
|
272 |
)
|
273 |
|
274 |
+
# Test the agent with a simple query
|
275 |
+
logger.info("Testing agent with a simple query...")
|
276 |
+
try:
|
277 |
+
test_query = "SELECT 1"
|
278 |
+
test_result = agent.run(test_query)
|
279 |
+
logger.info(f"Agent test query successful: {str(test_result)[:200]}...")
|
280 |
+
except Exception as e:
|
281 |
+
logger.warning(f"Agent test query failed (this might be expected): {str(e)}")
|
282 |
+
# Continue even if test fails, as it might be due to model limitations
|
|
|
283 |
|
284 |
logger.info("SQL agent created successfully")
|
285 |
return agent, ""
|
|
|
497 |
|
498 |
return messages
|
499 |
|
500 |
+
async def stream_agent_response(question: str, chat_history: List[List[str]]) -> Tuple[str, Optional["go.Figure"]]:
|
501 |
"""Procesa la pregunta del usuario y devuelve la respuesta del agente con memoria de conversación."""
|
502 |
global agent # Make sure we can modify the agent's memory
|
503 |
|
|
|
517 |
user_message = HumanMessage(content=question)
|
518 |
messages.append(user_message)
|
519 |
|
520 |
+
if not agent:
|
|
|
521 |
error_msg = (
|
522 |
"## ⚠️ Error: Agente no inicializado\n\n"
|
523 |
"No se pudo inicializar el agente de base de datos. Por favor, verifica que:\n"
|
|
|
531 |
# Update the agent's memory with the full conversation history
|
532 |
try:
|
533 |
# Rebuild agent memory from chat history pairs
|
534 |
+
if hasattr(agent, 'memory') and agent.memory is not None:
|
535 |
+
agent.memory.clear()
|
536 |
for i in range(0, len(messages)-1, 2): # (user, assistant)
|
537 |
if i+1 < len(messages):
|
538 |
+
agent.memory.save_context(
|
539 |
{"input": messages[i].content},
|
540 |
{"output": messages[i+1].content}
|
541 |
)
|
|
|
550 |
# Execute the agent with proper error handling
|
551 |
try:
|
552 |
# Let the agent use its memory; don't pass raw chat_history
|
553 |
+
response = await agent.ainvoke({"input": question})
|
554 |
logger.info(f"Agent response type: {type(response)}")
|
555 |
logger.info(f"Agent response content: {str(response)[:500]}...")
|
556 |
|
|
|
640 |
"Devuelve SOLO la consulta SQL en un bloque ```sql``` para responder a: "
|
641 |
f"{question}. No incluyas explicación ni texto adicional."
|
642 |
)
|
643 |
+
sql_only_resp = await agent.ainvoke({"input": sql_only_prompt})
|
644 |
sql_only_text = str(sql_only_resp)
|
645 |
sql_query2 = extract_sql_query(sql_only_text)
|
646 |
if sql_query2 and looks_like_sql(sql_query2):
|
|
|
866 |
if not env_ok:
|
867 |
gr.Warning("⚠️ " + env_message)
|
868 |
|
|
|
|
|
|
|
869 |
# Create the chat interface
|
870 |
with gr.Row():
|
871 |
chatbot = gr.Chatbot(
|
|
|
975 |
# Hidden component for streaming output
|
976 |
streaming_output_display = gr.Textbox(visible=False)
|
977 |
|
978 |
+
return demo, chatbot, chart_display, question_input, submit_button, streaming_output_display
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
979 |
|
980 |
def create_application():
|
981 |
"""Create and configure the Gradio application."""
|
982 |
# Create the UI components
|
983 |
+
demo, chatbot, chart_display, question_input, submit_button, streaming_output_display = create_ui()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
984 |
|
985 |
def user_message(user_input: str, chat_history: List[Dict[str, str]]) -> Tuple[str, List[Dict[str, str]]]:
|
986 |
"""Add user message to chat history (messages format) and clear input."""
|
|
|
997 |
|
998 |
return "", chat_history
|
999 |
|
1000 |
+
async def bot_response(chat_history: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Optional[go.Figure]]:
|
1001 |
+
"""Generate bot response for messages-format chat history and return optional chart figure."""
|
|
|
|
|
|
|
1002 |
if not chat_history:
|
1003 |
+
return chat_history, None
|
1004 |
|
1005 |
# Ensure last message is a user turn awaiting assistant reply
|
1006 |
last = chat_history[-1]
|
1007 |
if not isinstance(last, dict) or last.get("role") != "user" or not last.get("content"):
|
1008 |
+
return chat_history, None
|
1009 |
|
1010 |
try:
|
1011 |
question = last["content"]
|
1012 |
logger.info(f"Processing question: {question}")
|
1013 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1014 |
# Convert prior messages to pair history for stream_agent_response()
|
1015 |
pair_history: List[List[str]] = []
|
1016 |
i = 0
|
|
|
1029 |
i += 1
|
1030 |
|
1031 |
# Call the agent for this new user question
|
1032 |
+
assistant_message, chart_fig = await stream_agent_response(question, pair_history)
|
|
|
|
|
|
|
|
|
1033 |
|
1034 |
# Append assistant message back into messages history
|
1035 |
chat_history.append({"role": "assistant", "content": assistant_message})
|
|
|
1066 |
)
|
1067 |
|
1068 |
logger.info("Response generation complete")
|
1069 |
+
return chat_history, chart_fig
|
1070 |
|
1071 |
except Exception as e:
|
1072 |
error_msg = f"## ❌ Error\n\nError al procesar la solicitud:\n\n```\n{str(e)}\n```"
|
1073 |
logger.error(error_msg, exc_info=True)
|
1074 |
# Ensure we add an assistant error message for the UI
|
1075 |
chat_history.append({"role": "assistant", "content": error_msg})
|
1076 |
+
return chat_history, None
|
1077 |
|
1078 |
# Event handlers
|
1079 |
with demo:
|
|
|
1085 |
queue=True
|
1086 |
).then(
|
1087 |
fn=bot_response,
|
1088 |
+
inputs=[chatbot],
|
1089 |
+
outputs=[chatbot, chart_display],
|
1090 |
api_name="ask"
|
1091 |
)
|
1092 |
|
|
|
1098 |
queue=True
|
1099 |
).then(
|
1100 |
fn=bot_response,
|
1101 |
+
inputs=[chatbot],
|
1102 |
+
outputs=[chatbot, chart_display]
|
1103 |
)
|
1104 |
|
1105 |
return demo
|